peer.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package p2p
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net"
  7. "sort"
  8. "sync"
  9. "time"
  10. "github.com/ethereum/go-ethereum/logger"
  11. "github.com/ethereum/go-ethereum/logger/glog"
  12. "github.com/ethereum/go-ethereum/p2p/discover"
  13. "github.com/ethereum/go-ethereum/rlp"
  14. )
  // Parameters of the base (devp2p) protocol.
const (
	// baseProtocolVersion is the version of the base protocol.
	baseProtocolVersion = 4

	// baseProtocolLength is the number of message codes reserved for the
	// base protocol. Subprotocol message codes are allocated from this value
	// upwards (see matchProtocols); codes below it are handled by the peer.
	baseProtocolLength uint64 = 16

	// baseProtocolMaxMsgSize is 10 MiB.
	baseProtocolMaxMsgSize = 10 << 20

	// pingInterval is the delay between pings sent by pingLoop.
	pingInterval = 15 * time.Second
)
  // Message codes of the base devp2p protocol. These values are written to
// and read from the wire, so they must not change.
const (
	handshakeMsg = 0
	discMsg      = 1 // disconnect; payload is a one-element list holding the reason
	pingMsg      = 2 // answered with pongMsg by Peer.handle
	pongMsg      = 3
	getPeersMsg  = 4
	peersMsg     = 5
)
  30. // Peer represents a connected remote node.
  31. type Peer struct {
  32. conn net.Conn
  33. rw *conn
  34. running map[string]*protoRW
  35. wg sync.WaitGroup
  36. protoErr chan error
  37. closed chan struct{}
  38. disc chan DiscReason
  39. }
  40. // NewPeer returns a peer for testing purposes.
  41. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
  42. pipe, _ := net.Pipe()
  43. msgpipe, _ := MsgPipe()
  44. conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
  45. peer := newPeer(pipe, conn, nil)
  46. close(peer.closed) // ensures Disconnect doesn't block
  47. return peer
  48. }
  49. // ID returns the node's public key.
  50. func (p *Peer) ID() discover.NodeID {
  51. return p.rw.ID
  52. }
  53. // Name returns the node name that the remote node advertised.
  54. func (p *Peer) Name() string {
  55. return p.rw.Name
  56. }
  57. // Caps returns the capabilities (supported subprotocols) of the remote peer.
  58. func (p *Peer) Caps() []Cap {
  59. // TODO: maybe return copy
  60. return p.rw.Caps
  61. }
  62. // RemoteAddr returns the remote address of the network connection.
  63. func (p *Peer) RemoteAddr() net.Addr {
  64. return p.conn.RemoteAddr()
  65. }
  66. // LocalAddr returns the local address of the network connection.
  67. func (p *Peer) LocalAddr() net.Addr {
  68. return p.conn.LocalAddr()
  69. }
  70. // Disconnect terminates the peer connection with the given reason.
  71. // It returns immediately and does not wait until the connection is closed.
  72. func (p *Peer) Disconnect(reason DiscReason) {
  73. select {
  74. case p.disc <- reason:
  75. case <-p.closed:
  76. }
  77. }
  78. // String implements fmt.Stringer.
  79. func (p *Peer) String() string {
  80. return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
  81. }
  82. func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
  83. protomap := matchProtocols(protocols, conn.Caps, conn)
  84. p := &Peer{
  85. conn: fd,
  86. rw: conn,
  87. running: protomap,
  88. disc: make(chan DiscReason),
  89. protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop
  90. closed: make(chan struct{}),
  91. }
  92. return p
  93. }
  94. func (p *Peer) run() DiscReason {
  95. readErr := make(chan error, 1)
  96. p.wg.Add(2)
  97. go p.readLoop(readErr)
  98. go p.pingLoop()
  99. p.startProtocols()
  100. // Wait for an error or disconnect.
  101. var reason DiscReason
  102. select {
  103. case err := <-readErr:
  104. if r, ok := err.(DiscReason); ok {
  105. reason = r
  106. } else {
  107. // Note: We rely on protocols to abort if there is a write
  108. // error. It might be more robust to handle them here as well.
  109. glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
  110. reason = DiscNetworkError
  111. }
  112. case err := <-p.protoErr:
  113. reason = discReasonForError(err)
  114. case reason = <-p.disc:
  115. p.politeDisconnect(reason)
  116. reason = DiscRequested
  117. }
  118. close(p.closed)
  119. p.wg.Wait()
  120. glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
  121. return reason
  122. }
  123. func (p *Peer) politeDisconnect(reason DiscReason) {
  124. if reason != DiscNetworkError {
  125. SendItems(p.rw, discMsg, uint(reason))
  126. }
  127. p.conn.Close()
  128. }
  129. func (p *Peer) pingLoop() {
  130. ping := time.NewTicker(pingInterval)
  131. defer p.wg.Done()
  132. defer ping.Stop()
  133. for {
  134. select {
  135. case <-ping.C:
  136. if err := SendItems(p.rw, pingMsg); err != nil {
  137. p.protoErr <- err
  138. return
  139. }
  140. case <-p.closed:
  141. return
  142. }
  143. }
  144. }
  145. func (p *Peer) readLoop(errc chan<- error) {
  146. defer p.wg.Done()
  147. for {
  148. msg, err := p.rw.ReadMsg()
  149. if err != nil {
  150. errc <- err
  151. return
  152. }
  153. msg.ReceivedAt = time.Now()
  154. if err = p.handle(msg); err != nil {
  155. errc <- err
  156. return
  157. }
  158. }
  159. }
  160. func (p *Peer) handle(msg Msg) error {
  161. switch {
  162. case msg.Code == pingMsg:
  163. msg.Discard()
  164. go SendItems(p.rw, pongMsg)
  165. case msg.Code == discMsg:
  166. var reason [1]DiscReason
  167. // This is the last message. We don't need to discard or
  168. // check errors because, the connection will be closed after it.
  169. rlp.Decode(msg.Payload, &reason)
  170. glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0])
  171. return reason[0]
  172. case msg.Code < baseProtocolLength:
  173. // ignore other base protocol messages
  174. return msg.Discard()
  175. default:
  176. // it's a subprotocol message
  177. proto, err := p.getProto(msg.Code)
  178. if err != nil {
  179. return fmt.Errorf("msg code out of range: %v", msg.Code)
  180. }
  181. select {
  182. case proto.in <- msg:
  183. return nil
  184. case <-p.closed:
  185. return io.EOF
  186. }
  187. }
  188. return nil
  189. }
  190. func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
  191. n := 0
  192. for _, cap := range caps {
  193. for _, proto := range protocols {
  194. if proto.Name == cap.Name && proto.Version == cap.Version {
  195. n++
  196. }
  197. }
  198. }
  199. return n
  200. }
  201. // matchProtocols creates structures for matching named subprotocols.
  202. func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
  203. sort.Sort(capsByName(caps))
  204. offset := baseProtocolLength
  205. result := make(map[string]*protoRW)
  206. outer:
  207. for _, cap := range caps {
  208. for _, proto := range protocols {
  209. if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
  210. result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
  211. offset += proto.Length
  212. continue outer
  213. }
  214. }
  215. }
  216. return result
  217. }
  218. func (p *Peer) startProtocols() {
  219. p.wg.Add(len(p.running))
  220. for _, proto := range p.running {
  221. proto := proto
  222. proto.closed = p.closed
  223. glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
  224. go func() {
  225. err := proto.Run(p, proto)
  226. if err == nil {
  227. glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
  228. err = errors.New("protocol returned")
  229. } else if err != io.EOF {
  230. glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err)
  231. }
  232. p.protoErr <- err
  233. p.wg.Done()
  234. }()
  235. }
  236. }
  237. // getProto finds the protocol responsible for handling
  238. // the given message code.
  239. func (p *Peer) getProto(code uint64) (*protoRW, error) {
  240. for _, proto := range p.running {
  241. if code >= proto.offset && code < proto.offset+proto.Length {
  242. return proto, nil
  243. }
  244. }
  245. return nil, newPeerError(errInvalidMsgCode, "%d", code)
  246. }
  247. type protoRW struct {
  248. Protocol
  249. in chan Msg
  250. closed <-chan struct{}
  251. offset uint64
  252. w MsgWriter
  253. }
  254. func (rw *protoRW) WriteMsg(msg Msg) error {
  255. if msg.Code >= rw.Length {
  256. return newPeerError(errInvalidMsgCode, "not handled")
  257. }
  258. msg.Code += rw.offset
  259. return rw.w.WriteMsg(msg)
  260. }
  261. func (rw *protoRW) ReadMsg() (Msg, error) {
  262. select {
  263. case msg := <-rw.in:
  264. msg.Code -= rw.offset
  265. return msg, nil
  266. case <-rw.closed:
  267. return Msg{}, io.EOF
  268. }
  269. }