peer.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. package p2p
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net"
  8. "sort"
  9. "sync"
  10. "time"
  11. "github.com/ethereum/go-ethereum/logger"
  12. "github.com/ethereum/go-ethereum/p2p/discover"
  13. "github.com/ethereum/go-ethereum/rlp"
  14. )
  15. const (
  16. baseProtocolVersion = 3
  17. baseProtocolLength = uint64(16)
  18. baseProtocolMaxMsgSize = 10 * 1024 * 1024
  19. pingInterval = 15 * time.Second
  20. disconnectGracePeriod = 2 * time.Second
  21. )
  22. const (
  23. // devp2p message codes
  24. handshakeMsg = 0x00
  25. discMsg = 0x01
  26. pingMsg = 0x02
  27. pongMsg = 0x03
  28. getPeersMsg = 0x04
  29. peersMsg = 0x05
  30. )
// Peer represents a connected remote node.
type Peer struct {
	// Peers have all the log methods.
	// Use them to display messages related to the peer.
	*logger.Logger
	// conn is the underlying network connection. It provides the local and
	// remote addresses, has its deadline set by readLoop, and is closed
	// when the peer shuts down.
	conn net.Conn
	// rw is the message stream (ReadMsg/WriteMsg) together with the data
	// the remote advertised in its handshake (ID, Name, Caps).
	rw *conn
	// running holds the matched subprotocols, keyed by protocol name.
	running map[string]*protoRW
	// protoWG tracks the goroutines started for the running subprotocols.
	protoWG sync.WaitGroup
	// protoErr delivers errors from subprotocol goroutines (and failed
	// pings) to the run loop.
	protoErr chan error
	// closed is closed when the run loop shuts down (NewPeer closes it up
	// front). Disconnect selects on it, so it never blocks afterwards.
	closed chan struct{}
	// disc delivers Disconnect requests to the run loop.
	disc chan DiscReason
}
  44. // NewPeer returns a peer for testing purposes.
  45. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
  46. pipe, _ := net.Pipe()
  47. msgpipe, _ := MsgPipe()
  48. conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
  49. peer := newPeer(pipe, conn, nil)
  50. close(peer.closed) // ensures Disconnect doesn't block
  51. return peer
  52. }
  53. // ID returns the node's public key.
  54. func (p *Peer) ID() discover.NodeID {
  55. return p.rw.ID
  56. }
  57. // Name returns the node name that the remote node advertised.
  58. func (p *Peer) Name() string {
  59. return p.rw.Name
  60. }
  61. // Caps returns the capabilities (supported subprotocols) of the remote peer.
  62. func (p *Peer) Caps() []Cap {
  63. // TODO: maybe return copy
  64. return p.rw.Caps
  65. }
  66. // RemoteAddr returns the remote address of the network connection.
  67. func (p *Peer) RemoteAddr() net.Addr {
  68. return p.conn.RemoteAddr()
  69. }
  70. // LocalAddr returns the local address of the network connection.
  71. func (p *Peer) LocalAddr() net.Addr {
  72. return p.conn.LocalAddr()
  73. }
  74. // Disconnect terminates the peer connection with the given reason.
  75. // It returns immediately and does not wait until the connection is closed.
  76. func (p *Peer) Disconnect(reason DiscReason) {
  77. select {
  78. case p.disc <- reason:
  79. case <-p.closed:
  80. }
  81. }
  82. // String implements fmt.Stringer.
  83. func (p *Peer) String() string {
  84. return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
  85. }
  86. func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
  87. logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr())
  88. p := &Peer{
  89. Logger: logger.NewLogger(logtag),
  90. conn: fd,
  91. rw: conn,
  92. running: matchProtocols(protocols, conn.Caps, conn),
  93. disc: make(chan DiscReason),
  94. protoErr: make(chan error),
  95. closed: make(chan struct{}),
  96. }
  97. return p
  98. }
  99. func (p *Peer) run() DiscReason {
  100. var readErr = make(chan error, 1)
  101. defer p.closeProtocols()
  102. defer close(p.closed)
  103. p.startProtocols()
  104. go func() { readErr <- p.readLoop() }()
  105. ping := time.NewTicker(pingInterval)
  106. defer ping.Stop()
  107. // Wait for an error or disconnect.
  108. var reason DiscReason
  109. loop:
  110. for {
  111. select {
  112. case <-ping.C:
  113. go func() {
  114. if err := EncodeMsg(p.rw, pingMsg, nil); err != nil {
  115. p.protoErr <- err
  116. return
  117. }
  118. }()
  119. case err := <-readErr:
  120. // We rely on protocols to abort if there is a write error. It
  121. // might be more robust to handle them here as well.
  122. p.DebugDetailf("Read error: %v\n", err)
  123. p.conn.Close()
  124. return DiscNetworkError
  125. case err := <-p.protoErr:
  126. reason = discReasonForError(err)
  127. break loop
  128. case reason = <-p.disc:
  129. break loop
  130. }
  131. }
  132. p.politeDisconnect(reason)
  133. // Wait for readLoop. It will end because conn is now closed.
  134. <-readErr
  135. p.Debugf("Disconnected: %v\n", reason)
  136. return reason
  137. }
  138. func (p *Peer) politeDisconnect(reason DiscReason) {
  139. done := make(chan struct{})
  140. go func() {
  141. EncodeMsg(p.rw, discMsg, uint(reason))
  142. // Wait for the other side to close the connection.
  143. // Discard any data that they send until then.
  144. io.Copy(ioutil.Discard, p.conn)
  145. close(done)
  146. }()
  147. select {
  148. case <-done:
  149. case <-time.After(disconnectGracePeriod):
  150. }
  151. p.conn.Close()
  152. }
  153. func (p *Peer) readLoop() error {
  154. for {
  155. p.conn.SetDeadline(time.Now().Add(frameReadTimeout))
  156. msg, err := p.rw.ReadMsg()
  157. if err != nil {
  158. return err
  159. }
  160. if err = p.handle(msg); err != nil {
  161. return err
  162. }
  163. }
  164. return nil
  165. }
  166. func (p *Peer) handle(msg Msg) error {
  167. switch {
  168. case msg.Code == pingMsg:
  169. msg.Discard()
  170. go EncodeMsg(p.rw, pongMsg)
  171. case msg.Code == discMsg:
  172. var reason [1]DiscReason
  173. // no need to discard or for error checking, we'll close the
  174. // connection after this.
  175. rlp.Decode(msg.Payload, &reason)
  176. p.Disconnect(DiscRequested)
  177. return discRequestedError(reason[0])
  178. case msg.Code < baseProtocolLength:
  179. // ignore other base protocol messages
  180. return msg.Discard()
  181. default:
  182. // it's a subprotocol message
  183. proto, err := p.getProto(msg.Code)
  184. if err != nil {
  185. return fmt.Errorf("msg code out of range: %v", msg.Code)
  186. }
  187. proto.in <- msg
  188. }
  189. return nil
  190. }
  191. // matchProtocols creates structures for matching named subprotocols.
  192. func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
  193. sort.Sort(capsByName(caps))
  194. offset := baseProtocolLength
  195. result := make(map[string]*protoRW)
  196. outer:
  197. for _, cap := range caps {
  198. for _, proto := range protocols {
  199. if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
  200. result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
  201. offset += proto.Length
  202. continue outer
  203. }
  204. }
  205. }
  206. return result
  207. }
  208. func (p *Peer) startProtocols() {
  209. for _, proto := range p.running {
  210. proto := proto
  211. p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
  212. p.protoWG.Add(1)
  213. go func() {
  214. err := proto.Run(p, proto)
  215. if err == nil {
  216. p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
  217. err = errors.New("protocol returned")
  218. } else {
  219. p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
  220. }
  221. select {
  222. case p.protoErr <- err:
  223. case <-p.closed:
  224. }
  225. p.protoWG.Done()
  226. }()
  227. }
  228. }
  229. // getProto finds the protocol responsible for handling
  230. // the given message code.
  231. func (p *Peer) getProto(code uint64) (*protoRW, error) {
  232. for _, proto := range p.running {
  233. if code >= proto.offset && code < proto.offset+proto.Length {
  234. return proto, nil
  235. }
  236. }
  237. return nil, newPeerError(errInvalidMsgCode, "%d", code)
  238. }
  239. func (p *Peer) closeProtocols() {
  240. for _, p := range p.running {
  241. close(p.in)
  242. }
  243. p.protoWG.Wait()
  244. }
  245. // writeProtoMsg sends the given message on behalf of the given named protocol.
  246. // this exists because of Server.Broadcast.
  247. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
  248. proto, ok := p.running[protoName]
  249. if !ok {
  250. return fmt.Errorf("protocol %s not handled by peer", protoName)
  251. }
  252. if msg.Code >= proto.Length {
  253. return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
  254. }
  255. msg.Code += proto.offset
  256. return p.rw.WriteMsg(msg)
  257. }
  258. type protoRW struct {
  259. Protocol
  260. in chan Msg
  261. offset uint64
  262. w MsgWriter
  263. }
  264. func (rw *protoRW) WriteMsg(msg Msg) error {
  265. if msg.Code >= rw.Length {
  266. return newPeerError(errInvalidMsgCode, "not handled")
  267. }
  268. msg.Code += rw.offset
  269. return rw.w.WriteMsg(msg)
  270. }
  271. func (rw *protoRW) ReadMsg() (Msg, error) {
  272. msg, ok := <-rw.in
  273. if !ok {
  274. return msg, io.EOF
  275. }
  276. msg.Code -= rw.offset
  277. return msg, nil
  278. }