peer.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. package p2p
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net"
  8. "sort"
  9. "sync"
  10. "time"
  11. "github.com/ethereum/go-ethereum/logger"
  12. "github.com/ethereum/go-ethereum/p2p/discover"
  13. "github.com/ethereum/go-ethereum/rlp"
  14. )
  15. const (
  16. baseProtocolVersion = 3
  17. baseProtocolLength = uint64(16)
  18. baseProtocolMaxMsgSize = 10 * 1024 * 1024
  19. disconnectGracePeriod = 2 * time.Second
  20. )
  21. const (
  22. // devp2p message codes
  23. handshakeMsg = 0x00
  24. discMsg = 0x01
  25. pingMsg = 0x02
  26. pongMsg = 0x03
  27. getPeersMsg = 0x04
  28. peersMsg = 0x05
  29. )
  30. // Peer represents a connected remote node.
  31. type Peer struct {
  32. // Peers have all the log methods.
  33. // Use them to display messages related to the peer.
  34. *logger.Logger
  35. rw *conn
  36. running map[string]*protoRW
  37. protoWG sync.WaitGroup
  38. protoErr chan error
  39. closed chan struct{}
  40. disc chan DiscReason
  41. }
  42. // NewPeer returns a peer for testing purposes.
  43. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
  44. pipe, _ := net.Pipe()
  45. conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
  46. peer := newPeer(conn, nil)
  47. close(peer.closed) // ensures Disconnect doesn't block
  48. return peer
  49. }
  50. // ID returns the node's public key.
  51. func (p *Peer) ID() discover.NodeID {
  52. return p.rw.ID
  53. }
  54. // Name returns the node name that the remote node advertised.
  55. func (p *Peer) Name() string {
  56. return p.rw.Name
  57. }
  58. // Caps returns the capabilities (supported subprotocols) of the remote peer.
  59. func (p *Peer) Caps() []Cap {
  60. // TODO: maybe return copy
  61. return p.rw.Caps
  62. }
  63. // RemoteAddr returns the remote address of the network connection.
  64. func (p *Peer) RemoteAddr() net.Addr {
  65. return p.rw.RemoteAddr()
  66. }
  67. // LocalAddr returns the local address of the network connection.
  68. func (p *Peer) LocalAddr() net.Addr {
  69. return p.rw.LocalAddr()
  70. }
  71. // Disconnect terminates the peer connection with the given reason.
  72. // It returns immediately and does not wait until the connection is closed.
  73. func (p *Peer) Disconnect(reason DiscReason) {
  74. select {
  75. case p.disc <- reason:
  76. case <-p.closed:
  77. }
  78. }
  79. // String implements fmt.Stringer.
  80. func (p *Peer) String() string {
  81. return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
  82. }
  83. func newPeer(conn *conn, protocols []Protocol) *Peer {
  84. logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
  85. p := &Peer{
  86. Logger: logger.NewLogger(logtag),
  87. rw: conn,
  88. running: matchProtocols(protocols, conn.Caps, conn),
  89. disc: make(chan DiscReason),
  90. protoErr: make(chan error),
  91. closed: make(chan struct{}),
  92. }
  93. return p
  94. }
  95. func (p *Peer) run() DiscReason {
  96. var readErr = make(chan error, 1)
  97. defer p.closeProtocols()
  98. defer close(p.closed)
  99. p.startProtocols()
  100. go func() { readErr <- p.readLoop() }()
  101. // Wait for an error or disconnect.
  102. var reason DiscReason
  103. select {
  104. case err := <-readErr:
  105. // We rely on protocols to abort if there is a write error. It
  106. // might be more robust to handle them here as well.
  107. p.DebugDetailf("Read error: %v\n", err)
  108. p.rw.Close()
  109. return DiscNetworkError
  110. case err := <-p.protoErr:
  111. reason = discReasonForError(err)
  112. case reason = <-p.disc:
  113. }
  114. p.politeDisconnect(reason)
  115. // Wait for readLoop. It will end because conn is now closed.
  116. <-readErr
  117. p.Debugf("Disconnected: %v\n", reason)
  118. return reason
  119. }
  120. func (p *Peer) politeDisconnect(reason DiscReason) {
  121. done := make(chan struct{})
  122. go func() {
  123. EncodeMsg(p.rw, discMsg, uint(reason))
  124. // Wait for the other side to close the connection.
  125. // Discard any data that they send until then.
  126. io.Copy(ioutil.Discard, p.rw)
  127. close(done)
  128. }()
  129. select {
  130. case <-done:
  131. case <-time.After(disconnectGracePeriod):
  132. }
  133. p.rw.Close()
  134. }
  135. func (p *Peer) readLoop() error {
  136. for {
  137. msg, err := p.rw.ReadMsg()
  138. if err != nil {
  139. return err
  140. }
  141. if err = p.handle(msg); err != nil {
  142. return err
  143. }
  144. }
  145. return nil
  146. }
  147. func (p *Peer) handle(msg Msg) error {
  148. switch {
  149. case msg.Code == pingMsg:
  150. msg.Discard()
  151. go EncodeMsg(p.rw, pongMsg)
  152. case msg.Code == discMsg:
  153. var reason DiscReason
  154. // no need to discard or for error checking, we'll close the
  155. // connection after this.
  156. rlp.Decode(msg.Payload, &reason)
  157. p.Disconnect(DiscRequested)
  158. return discRequestedError(reason)
  159. case msg.Code < baseProtocolLength:
  160. // ignore other base protocol messages
  161. return msg.Discard()
  162. default:
  163. // it's a subprotocol message
  164. proto, err := p.getProto(msg.Code)
  165. if err != nil {
  166. return fmt.Errorf("msg code out of range: %v", msg.Code)
  167. }
  168. proto.in <- msg
  169. }
  170. return nil
  171. }
  172. // matchProtocols creates structures for matching named subprotocols.
  173. func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
  174. sort.Sort(capsByName(caps))
  175. offset := baseProtocolLength
  176. result := make(map[string]*protoRW)
  177. outer:
  178. for _, cap := range caps {
  179. for _, proto := range protocols {
  180. if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
  181. result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
  182. offset += proto.Length
  183. continue outer
  184. }
  185. }
  186. }
  187. return result
  188. }
  189. func (p *Peer) startProtocols() {
  190. for _, proto := range p.running {
  191. proto := proto
  192. p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
  193. p.protoWG.Add(1)
  194. go func() {
  195. err := proto.Run(p, proto)
  196. if err == nil {
  197. p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
  198. err = errors.New("protocol returned")
  199. } else {
  200. p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
  201. }
  202. select {
  203. case p.protoErr <- err:
  204. case <-p.closed:
  205. }
  206. p.protoWG.Done()
  207. }()
  208. }
  209. }
  210. // getProto finds the protocol responsible for handling
  211. // the given message code.
  212. func (p *Peer) getProto(code uint64) (*protoRW, error) {
  213. for _, proto := range p.running {
  214. if code >= proto.offset && code < proto.offset+proto.Length {
  215. return proto, nil
  216. }
  217. }
  218. return nil, newPeerError(errInvalidMsgCode, "%d", code)
  219. }
  220. func (p *Peer) closeProtocols() {
  221. for _, p := range p.running {
  222. close(p.in)
  223. }
  224. p.protoWG.Wait()
  225. }
  226. // writeProtoMsg sends the given message on behalf of the given named protocol.
  227. // this exists because of Server.Broadcast.
  228. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
  229. proto, ok := p.running[protoName]
  230. if !ok {
  231. return fmt.Errorf("protocol %s not handled by peer", protoName)
  232. }
  233. if msg.Code >= proto.Length {
  234. return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
  235. }
  236. msg.Code += proto.offset
  237. return p.rw.WriteMsg(msg)
  238. }
  239. type protoRW struct {
  240. Protocol
  241. in chan Msg
  242. offset uint64
  243. w MsgWriter
  244. }
  245. func (rw *protoRW) WriteMsg(msg Msg) error {
  246. if msg.Code >= rw.Length {
  247. return newPeerError(errInvalidMsgCode, "not handled")
  248. }
  249. msg.Code += rw.offset
  250. return rw.w.WriteMsg(msg)
  251. }
  252. func (rw *protoRW) ReadMsg() (Msg, error) {
  253. msg, ok := <-rw.in
  254. if !ok {
  255. return msg, io.EOF
  256. }
  257. msg.Code -= rw.offset
  258. return msg, nil
  259. }