peer.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. package p2p
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net"
  8. "sort"
  9. "sync"
  10. "time"
  11. "github.com/ethereum/go-ethereum/logger"
  12. "github.com/ethereum/go-ethereum/p2p/discover"
  13. "github.com/ethereum/go-ethereum/rlp"
  14. )
const (
	// baseProtocolVersion is the devp2p base protocol version we send in
	// our handshake and require in the remote handshake.
	baseProtocolVersion = 2
	// baseProtocolLength is the number of message codes reserved for the
	// base protocol. Codes below it are handled by Peer itself; subprotocol
	// code ranges are assigned starting at this value.
	baseProtocolLength = uint64(16)
	// baseProtocolMaxMsgSize is an upper bound on base protocol message
	// size (10 MiB); in this file it is enforced on the handshake message.
	baseProtocolMaxMsgSize = 10 * 1024 * 1024
	// disconnectGracePeriod bounds how long politeDisconnect waits for the
	// remote side to close the connection after we send a disconnect.
	disconnectGracePeriod = 2 * time.Second
)

const (
	// devp2p message codes
	handshakeMsg = 0x00
	discMsg      = 0x01
	pingMsg      = 0x02
	pongMsg      = 0x03
	getPeersMsg  = 0x04
	peersMsg     = 0x05
)
// handshake is the RLP structure of the protocol handshake.
//
// Field order is significant: it defines the encoded list layout and must
// stay in step with the argument order passed to EncodeMsg by
// writeProtocolHandshake (version, name, caps, listen port, node ID).
type handshake struct {
	Version    uint64          // must equal baseProtocolVersion
	Name       string          // client name advertised by the node
	Caps       []Cap           // subprotocols the node supports
	ListenPort uint64          // advertised listen port (we always send 0)
	NodeID     discover.NodeID // compared against the peer's remoteID
}
// Peer represents a connected remote node.
type Peer struct {
	// Peers have all the log methods.
	// Use them to display messages related to the peer.
	*logger.Logger

	// infoMu guards name and caps, which are filled in from the remote
	// protocol handshake (see setHandshakeInfo, Name and Caps).
	infoMu sync.Mutex
	name   string
	caps   []Cap

	// ourID may be nil when the handshake is disabled (see NewPeer);
	// remoteID is always set by newPeer's callers.
	ourID, remoteID *discover.NodeID
	ourName         string
	rw              *frameRW

	// These fields maintain the running protocols.
	protocols []Protocol
	runlock   sync.RWMutex // protects running
	running   map[string]*proto

	// disables protocol handshake, for testing
	noHandshake bool

	protoWG  sync.WaitGroup  // counts subprotocol goroutines launched by startProto
	protoErr chan error      // a finished subprotocol reports its result here
	closed   chan struct{}   // closed when run exits (NewPeer closes it immediately)
	disc     chan DiscReason // Disconnect delivers its reason to run through here
}
  60. // NewPeer returns a peer for testing purposes.
  61. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
  62. conn, _ := net.Pipe()
  63. peer := newPeer(conn, nil, "", nil, &id)
  64. peer.setHandshakeInfo(name, caps)
  65. close(peer.closed) // ensures Disconnect doesn't block
  66. return peer
  67. }
  68. // ID returns the node's public key.
  69. func (p *Peer) ID() discover.NodeID {
  70. return *p.remoteID
  71. }
  72. // Name returns the node name that the remote node advertised.
  73. func (p *Peer) Name() string {
  74. // this needs a lock because the information is part of the
  75. // protocol handshake.
  76. p.infoMu.Lock()
  77. name := p.name
  78. p.infoMu.Unlock()
  79. return name
  80. }
  81. // Caps returns the capabilities (supported subprotocols) of the remote peer.
  82. func (p *Peer) Caps() []Cap {
  83. // this needs a lock because the information is part of the
  84. // protocol handshake.
  85. p.infoMu.Lock()
  86. caps := p.caps
  87. p.infoMu.Unlock()
  88. return caps
  89. }
  90. // RemoteAddr returns the remote address of the network connection.
  91. func (p *Peer) RemoteAddr() net.Addr {
  92. return p.rw.RemoteAddr()
  93. }
  94. // LocalAddr returns the local address of the network connection.
  95. func (p *Peer) LocalAddr() net.Addr {
  96. return p.rw.LocalAddr()
  97. }
  98. // Disconnect terminates the peer connection with the given reason.
  99. // It returns immediately and does not wait until the connection is closed.
  100. func (p *Peer) Disconnect(reason DiscReason) {
  101. select {
  102. case p.disc <- reason:
  103. case <-p.closed:
  104. }
  105. }
  106. // String implements fmt.Stringer.
  107. func (p *Peer) String() string {
  108. return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
  109. }
  110. func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
  111. logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
  112. return &Peer{
  113. Logger: logger.NewLogger(logtag),
  114. rw: newFrameRW(conn, msgWriteTimeout),
  115. ourID: ourID,
  116. ourName: ourName,
  117. remoteID: remoteID,
  118. protocols: protocols,
  119. running: make(map[string]*proto),
  120. disc: make(chan DiscReason),
  121. protoErr: make(chan error),
  122. closed: make(chan struct{}),
  123. }
  124. }
  125. func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
  126. p.infoMu.Lock()
  127. p.name = name
  128. p.caps = caps
  129. p.infoMu.Unlock()
  130. }
// run drives the peer connection until it ends and returns the disconnect
// reason.
//
// It starts readLoop in the background and sends our protocol handshake
// (unless noHandshake is set). It then waits for the first of: a read error,
// an error reported by a subprotocol, or a Disconnect request. The deferred
// calls run in reverse order on exit: p.closed is closed first, then
// closeProtocols closes every protocol's input channel and waits for the
// protocol goroutines.
func (p *Peer) run() DiscReason {
	// Buffered so readLoop can always deliver its result, even after run
	// has returned through one of the early paths below.
	var readErr = make(chan error, 1)
	defer p.closeProtocols()
	defer close(p.closed)
	go func() { readErr <- p.readLoop() }()
	if !p.noHandshake {
		if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
			p.DebugDetailf("Protocol handshake error: %v\n", err)
			// NOTE(review): this path returns without waiting for readLoop.
			// If readLoop has already started subprotocols and is blocked
			// forwarding a message in handle when closeProtocols closes the
			// protocol input channels, that send would panic. Confirm that
			// closing rw always stops readLoop first.
			p.rw.Close()
			return DiscProtocolError
		}
	}
	// Wait for an error or disconnect.
	var reason DiscReason
	select {
	case err := <-readErr:
		// We rely on protocols to abort if there is a write error. It
		// might be more robust to handle them here as well.
		p.DebugDetailf("Read error: %v\n", err)
		p.rw.Close()
		return DiscNetworkError
	case err := <-p.protoErr:
		reason = discReasonForError(err)
	case reason = <-p.disc:
	}
	// NOTE(review): readLoop is still reading p.rw here while
	// politeDisconnect drains p.rw with io.Copy; confirm frameRW tolerates
	// two concurrent readers.
	p.politeDisconnect(reason)
	// Wait for readLoop. It will end because conn is now closed.
	<-readErr
	p.Debugf("Disconnected: %v\n", reason)
	return reason
}
  162. func (p *Peer) politeDisconnect(reason DiscReason) {
  163. done := make(chan struct{})
  164. go func() {
  165. EncodeMsg(p.rw, discMsg, uint(reason))
  166. // Wait for the other side to close the connection.
  167. // Discard any data that they send until then.
  168. io.Copy(ioutil.Discard, p.rw)
  169. close(done)
  170. }()
  171. select {
  172. case <-done:
  173. case <-time.After(disconnectGracePeriod):
  174. }
  175. p.rw.Close()
  176. }
  177. func (p *Peer) readLoop() error {
  178. if !p.noHandshake {
  179. if err := readProtocolHandshake(p, p.rw); err != nil {
  180. return err
  181. }
  182. }
  183. for {
  184. msg, err := p.rw.ReadMsg()
  185. if err != nil {
  186. return err
  187. }
  188. if err = p.handle(msg); err != nil {
  189. return err
  190. }
  191. }
  192. return nil
  193. }
  194. func (p *Peer) handle(msg Msg) error {
  195. switch {
  196. case msg.Code == pingMsg:
  197. msg.Discard()
  198. go EncodeMsg(p.rw, pongMsg)
  199. case msg.Code == discMsg:
  200. var reason DiscReason
  201. // no need to discard or for error checking, we'll close the
  202. // connection after this.
  203. rlp.Decode(msg.Payload, &reason)
  204. p.Disconnect(DiscRequested)
  205. return discRequestedError(reason)
  206. case msg.Code < baseProtocolLength:
  207. // ignore other base protocol messages
  208. return msg.Discard()
  209. default:
  210. // it's a subprotocol message
  211. proto, err := p.getProto(msg.Code)
  212. if err != nil {
  213. return fmt.Errorf("msg code out of range: %v", msg.Code)
  214. }
  215. proto.in <- msg
  216. }
  217. return nil
  218. }
  219. func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
  220. // read and handle remote handshake
  221. msg, err := rw.ReadMsg()
  222. if err != nil {
  223. return err
  224. }
  225. if msg.Code != handshakeMsg {
  226. return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
  227. }
  228. if msg.Size > baseProtocolMaxMsgSize {
  229. return newPeerError(errInvalidMsg, "message too big")
  230. }
  231. var hs handshake
  232. if err := msg.Decode(&hs); err != nil {
  233. return err
  234. }
  235. // validate handshake info
  236. if hs.Version != baseProtocolVersion {
  237. return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
  238. baseProtocolVersion, hs.Version)
  239. }
  240. if hs.NodeID == *p.remoteID {
  241. return newPeerError(errPubkeyForbidden, "node ID mismatch")
  242. }
  243. // TODO: remove Caps with empty name
  244. p.setHandshakeInfo(hs.Name, hs.Caps)
  245. p.startSubprotocols(hs.Caps)
  246. return nil
  247. }
  248. func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
  249. var caps []interface{}
  250. for _, proto := range ps {
  251. caps = append(caps, proto.cap())
  252. }
  253. return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
  254. }
  255. // startProtocols starts matching named subprotocols.
  256. func (p *Peer) startSubprotocols(caps []Cap) {
  257. sort.Sort(capsByName(caps))
  258. p.runlock.Lock()
  259. defer p.runlock.Unlock()
  260. offset := baseProtocolLength
  261. outer:
  262. for _, cap := range caps {
  263. for _, proto := range p.protocols {
  264. if proto.Name == cap.Name &&
  265. proto.Version == cap.Version &&
  266. p.running[cap.Name] == nil {
  267. p.running[cap.Name] = p.startProto(offset, proto)
  268. offset += proto.Length
  269. continue outer
  270. }
  271. }
  272. }
  273. }
  274. func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
  275. p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
  276. rw := &proto{
  277. name: impl.Name,
  278. in: make(chan Msg),
  279. offset: offset,
  280. maxcode: impl.Length,
  281. w: p.rw,
  282. }
  283. p.protoWG.Add(1)
  284. go func() {
  285. err := impl.Run(p, rw)
  286. if err == nil {
  287. p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
  288. err = errors.New("protocol returned")
  289. } else {
  290. p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
  291. }
  292. select {
  293. case p.protoErr <- err:
  294. case <-p.closed:
  295. }
  296. p.protoWG.Done()
  297. }()
  298. return rw
  299. }
  300. // getProto finds the protocol responsible for handling
  301. // the given message code.
  302. func (p *Peer) getProto(code uint64) (*proto, error) {
  303. p.runlock.RLock()
  304. defer p.runlock.RUnlock()
  305. for _, proto := range p.running {
  306. if code >= proto.offset && code < proto.offset+proto.maxcode {
  307. return proto, nil
  308. }
  309. }
  310. return nil, newPeerError(errInvalidMsgCode, "%d", code)
  311. }
  312. func (p *Peer) closeProtocols() {
  313. p.runlock.RLock()
  314. for _, p := range p.running {
  315. close(p.in)
  316. }
  317. p.runlock.RUnlock()
  318. p.protoWG.Wait()
  319. }
  320. // writeProtoMsg sends the given message on behalf of the given named protocol.
  321. // this exists because of Server.Broadcast.
  322. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
  323. p.runlock.RLock()
  324. proto, ok := p.running[protoName]
  325. p.runlock.RUnlock()
  326. if !ok {
  327. return fmt.Errorf("protocol %s not handled by peer", protoName)
  328. }
  329. if msg.Code >= proto.maxcode {
  330. return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
  331. }
  332. msg.Code += proto.offset
  333. return p.rw.WriteMsg(msg)
  334. }
// proto is the reader/writer handed to a running subprotocol's Run. It
// translates between the protocol's local message codes [0, maxcode) and the
// connection's codes [offset, offset+maxcode).
type proto struct {
	name            string
	in              chan Msg // fed by Peer.handle; closed by closeProtocols
	maxcode, offset uint64
	w               MsgWriter // the peer's connection writer
}
  341. func (rw *proto) WriteMsg(msg Msg) error {
  342. if msg.Code >= rw.maxcode {
  343. return newPeerError(errInvalidMsgCode, "not handled")
  344. }
  345. msg.Code += rw.offset
  346. return rw.w.WriteMsg(msg)
  347. }
  348. func (rw *proto) ReadMsg() (Msg, error) {
  349. msg, ok := <-rw.in
  350. if !ok {
  351. return msg, io.EOF
  352. }
  353. msg.Code -= rw.offset
  354. return msg, nil
  355. }