peer.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. package p2p
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net"
  8. "sort"
  9. "sync"
  10. "time"
  11. "github.com/ethereum/go-ethereum/logger"
  12. "github.com/ethereum/go-ethereum/p2p/discover"
  13. "github.com/ethereum/go-ethereum/rlp"
  14. )
  const (
  	// Per-message I/O deadlines.
  	msgReadTimeout  = 5 * time.Second // upper bound for reading one message
  	msgWriteTimeout = 5 * time.Second // upper bound for writing one message

  	// Payloads smaller than this (64 KiB) are read in one go before the
  	// message is handed to a protocol.
  	wholePayloadSize = 64 << 10

  	// disconnectGracePeriod bounds how long politeDisconnect waits for the
  	// remote side to hang up after we send a disconnect message.
  	disconnectGracePeriod = 2 * time.Second
  )
  // Parameters of the devp2p base protocol.
  const (
  	baseProtocolVersion    = 2          // version sent in, and required from, the handshake
  	baseProtocolLength     = uint64(16) // codes below this belong to the base protocol
  	baseProtocolMaxMsgSize = 10 << 20   // 10 MiB size limit for the handshake message
  )
  // devp2p base protocol message codes. These values go on the wire, so the
  // order of the declarations below must not change.
  const (
  	handshakeMsg = iota // 0x00
  	discMsg             // 0x01
  	pingMsg             // 0x02
  	pongMsg             // 0x03
  	getPeersMsg         // 0x04
  	peersMsg            // 0x05
  )
  // handshake is the RLP structure of the protocol handshake, the first
  // message readProtocolHandshake expects on a connection.
  //
  // Field order matters: writeProtocolHandshake encodes its values
  // positionally in this same order, and readProtocolHandshake decodes the
  // remote reply into this struct, so reordering fields breaks the exchange.
  type handshake struct {
  	Version    uint64          // compared against baseProtocolVersion
  	Name       string          // remote client name, stored via setHandshakeInfo
  	Caps       []Cap           // advertised subprotocols, matched in startSubprotocols
  	ListenPort uint64          // writeProtocolHandshake always sends 0
  	NodeID     discover.NodeID // compared against the peer's remoteID
  }
  47. // Peer represents a connected remote node.
  48. type Peer struct {
  49. // Peers have all the log methods.
  50. // Use them to display messages related to the peer.
  51. *logger.Logger
  52. infoMu sync.Mutex
  53. name string
  54. caps []Cap
  55. ourID, remoteID *discover.NodeID
  56. ourName string
  57. rw *frameRW
  58. // These fields maintain the running protocols.
  59. protocols []Protocol
  60. runlock sync.RWMutex // protects running
  61. running map[string]*proto
  62. // disables protocol handshake, for testing
  63. noHandshake bool
  64. protoWG sync.WaitGroup
  65. protoErr chan error
  66. closed chan struct{}
  67. disc chan DiscReason
  68. }
  69. // NewPeer returns a peer for testing purposes.
  70. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
  71. conn, _ := net.Pipe()
  72. peer := newPeer(conn, nil, "", nil, &id)
  73. peer.setHandshakeInfo(name, caps)
  74. close(peer.closed) // ensures Disconnect doesn't block
  75. return peer
  76. }
  77. // ID returns the node's public key.
  78. func (p *Peer) ID() discover.NodeID {
  79. return *p.remoteID
  80. }
  81. // Name returns the node name that the remote node advertised.
  82. func (p *Peer) Name() string {
  83. // this needs a lock because the information is part of the
  84. // protocol handshake.
  85. p.infoMu.Lock()
  86. name := p.name
  87. p.infoMu.Unlock()
  88. return name
  89. }
  90. // Caps returns the capabilities (supported subprotocols) of the remote peer.
  91. func (p *Peer) Caps() []Cap {
  92. // this needs a lock because the information is part of the
  93. // protocol handshake.
  94. p.infoMu.Lock()
  95. caps := p.caps
  96. p.infoMu.Unlock()
  97. return caps
  98. }
  99. // RemoteAddr returns the remote address of the network connection.
  100. func (p *Peer) RemoteAddr() net.Addr {
  101. return p.rw.RemoteAddr()
  102. }
  103. // LocalAddr returns the local address of the network connection.
  104. func (p *Peer) LocalAddr() net.Addr {
  105. return p.rw.LocalAddr()
  106. }
  107. // Disconnect terminates the peer connection with the given reason.
  108. // It returns immediately and does not wait until the connection is closed.
  109. func (p *Peer) Disconnect(reason DiscReason) {
  110. select {
  111. case p.disc <- reason:
  112. case <-p.closed:
  113. }
  114. }
  115. // String implements fmt.Stringer.
  116. func (p *Peer) String() string {
  117. return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
  118. }
  119. func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
  120. logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
  121. return &Peer{
  122. Logger: logger.NewLogger(logtag),
  123. rw: newFrameRW(conn, msgWriteTimeout),
  124. ourID: ourID,
  125. ourName: ourName,
  126. remoteID: remoteID,
  127. protocols: protocols,
  128. running: make(map[string]*proto),
  129. disc: make(chan DiscReason),
  130. protoErr: make(chan error),
  131. closed: make(chan struct{}),
  132. }
  133. }
  134. func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
  135. p.infoMu.Lock()
  136. p.name = name
  137. p.caps = caps
  138. p.infoMu.Unlock()
  139. }
  140. func (p *Peer) run() DiscReason {
  141. var readErr = make(chan error, 1)
  142. defer p.closeProtocols()
  143. defer close(p.closed)
  144. go func() { readErr <- p.readLoop() }()
  145. if !p.noHandshake {
  146. if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
  147. p.DebugDetailf("Protocol handshake error: %v\n", err)
  148. p.rw.Close()
  149. return DiscProtocolError
  150. }
  151. }
  152. // Wait for an error or disconnect.
  153. var reason DiscReason
  154. select {
  155. case err := <-readErr:
  156. // We rely on protocols to abort if there is a write error. It
  157. // might be more robust to handle them here as well.
  158. p.DebugDetailf("Read error: %v\n", err)
  159. p.rw.Close()
  160. return DiscNetworkError
  161. case err := <-p.protoErr:
  162. reason = discReasonForError(err)
  163. case reason = <-p.disc:
  164. }
  165. p.politeDisconnect(reason)
  166. // Wait for readLoop. It will end because conn is now closed.
  167. <-readErr
  168. p.Debugf("Disconnected: %v\n", reason)
  169. return reason
  170. }
  171. func (p *Peer) politeDisconnect(reason DiscReason) {
  172. done := make(chan struct{})
  173. go func() {
  174. EncodeMsg(p.rw, discMsg, uint(reason))
  175. // Wait for the other side to close the connection.
  176. // Discard any data that they send until then.
  177. io.Copy(ioutil.Discard, p.rw)
  178. close(done)
  179. }()
  180. select {
  181. case <-done:
  182. case <-time.After(disconnectGracePeriod):
  183. }
  184. p.rw.Close()
  185. }
  186. func (p *Peer) readLoop() error {
  187. if !p.noHandshake {
  188. if err := readProtocolHandshake(p, p.rw); err != nil {
  189. return err
  190. }
  191. }
  192. for {
  193. msg, err := p.rw.ReadMsg()
  194. if err != nil {
  195. return err
  196. }
  197. if err = p.handle(msg); err != nil {
  198. return err
  199. }
  200. }
  201. return nil
  202. }
  203. func (p *Peer) handle(msg Msg) error {
  204. switch {
  205. case msg.Code == pingMsg:
  206. msg.Discard()
  207. go EncodeMsg(p.rw, pongMsg)
  208. case msg.Code == discMsg:
  209. var reason DiscReason
  210. // no need to discard or for error checking, we'll close the
  211. // connection after this.
  212. rlp.Decode(msg.Payload, &reason)
  213. p.Disconnect(DiscRequested)
  214. return discRequestedError(reason)
  215. case msg.Code < baseProtocolLength:
  216. // ignore other base protocol messages
  217. return msg.Discard()
  218. default:
  219. // it's a subprotocol message
  220. proto, err := p.getProto(msg.Code)
  221. if err != nil {
  222. return fmt.Errorf("msg code out of range: %v", msg.Code)
  223. }
  224. proto.in <- msg
  225. }
  226. return nil
  227. }
  228. func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
  229. // read and handle remote handshake
  230. msg, err := rw.ReadMsg()
  231. if err != nil {
  232. return err
  233. }
  234. if msg.Code != handshakeMsg {
  235. return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
  236. }
  237. if msg.Size > baseProtocolMaxMsgSize {
  238. return newPeerError(errInvalidMsg, "message too big")
  239. }
  240. var hs handshake
  241. if err := msg.Decode(&hs); err != nil {
  242. return err
  243. }
  244. // validate handshake info
  245. if hs.Version != baseProtocolVersion {
  246. return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
  247. baseProtocolVersion, hs.Version)
  248. }
  249. if hs.NodeID == *p.remoteID {
  250. return newPeerError(errPubkeyForbidden, "node ID mismatch")
  251. }
  252. // TODO: remove Caps with empty name
  253. p.setHandshakeInfo(hs.Name, hs.Caps)
  254. p.startSubprotocols(hs.Caps)
  255. return nil
  256. }
  257. func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
  258. var caps []interface{}
  259. for _, proto := range ps {
  260. caps = append(caps, proto.cap())
  261. }
  262. return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
  263. }
  264. // startProtocols starts matching named subprotocols.
  265. func (p *Peer) startSubprotocols(caps []Cap) {
  266. sort.Sort(capsByName(caps))
  267. p.runlock.Lock()
  268. defer p.runlock.Unlock()
  269. offset := baseProtocolLength
  270. outer:
  271. for _, cap := range caps {
  272. for _, proto := range p.protocols {
  273. if proto.Name == cap.Name &&
  274. proto.Version == cap.Version &&
  275. p.running[cap.Name] == nil {
  276. p.running[cap.Name] = p.startProto(offset, proto)
  277. offset += proto.Length
  278. continue outer
  279. }
  280. }
  281. }
  282. }
  283. func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
  284. p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
  285. rw := &proto{
  286. name: impl.Name,
  287. in: make(chan Msg),
  288. offset: offset,
  289. maxcode: impl.Length,
  290. w: p.rw,
  291. }
  292. p.protoWG.Add(1)
  293. go func() {
  294. err := impl.Run(p, rw)
  295. if err == nil {
  296. p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
  297. err = errors.New("protocol returned")
  298. } else {
  299. p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
  300. }
  301. select {
  302. case p.protoErr <- err:
  303. case <-p.closed:
  304. }
  305. p.protoWG.Done()
  306. }()
  307. return rw
  308. }
  309. // getProto finds the protocol responsible for handling
  310. // the given message code.
  311. func (p *Peer) getProto(code uint64) (*proto, error) {
  312. p.runlock.RLock()
  313. defer p.runlock.RUnlock()
  314. for _, proto := range p.running {
  315. if code >= proto.offset && code < proto.offset+proto.maxcode {
  316. return proto, nil
  317. }
  318. }
  319. return nil, newPeerError(errInvalidMsgCode, "%d", code)
  320. }
  321. func (p *Peer) closeProtocols() {
  322. p.runlock.RLock()
  323. for _, p := range p.running {
  324. close(p.in)
  325. }
  326. p.runlock.RUnlock()
  327. p.protoWG.Wait()
  328. }
  329. // writeProtoMsg sends the given message on behalf of the given named protocol.
  330. // this exists because of Server.Broadcast.
  331. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
  332. p.runlock.RLock()
  333. proto, ok := p.running[protoName]
  334. p.runlock.RUnlock()
  335. if !ok {
  336. return fmt.Errorf("protocol %s not handled by peer", protoName)
  337. }
  338. if msg.Code >= proto.maxcode {
  339. return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
  340. }
  341. msg.Code += proto.offset
  342. return p.rw.WriteMsg(msg)
  343. }
  344. type proto struct {
  345. name string
  346. in chan Msg
  347. maxcode, offset uint64
  348. w MsgWriter
  349. }
  350. func (rw *proto) WriteMsg(msg Msg) error {
  351. if msg.Code >= rw.maxcode {
  352. return newPeerError(errInvalidMsgCode, "not handled")
  353. }
  354. msg.Code += rw.offset
  355. return rw.w.WriteMsg(msg)
  356. }
  357. func (rw *proto) ReadMsg() (Msg, error) {
  358. msg, ok := <-rw.in
  359. if !ok {
  360. return msg, io.EOF
  361. }
  362. msg.Code -= rw.offset
  363. return msg, nil
  364. }