peer.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. package p2p
  2. import (
  3. "bufio"
  4. "bytes"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net"
  9. "sort"
  10. "sync"
  11. "time"
  12. "github.com/ethereum/go-ethereum/event"
  13. "github.com/ethereum/go-ethereum/logger"
  14. )
  15. // peerAddr is the structure of a peer list element.
  16. // It is also a valid net.Addr.
  17. type peerAddr struct {
  18. IP net.IP
  19. Port uint64
  20. Pubkey []byte // optional
  21. }
  22. func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
  23. n := addr.Network()
  24. if n != "tcp" && n != "tcp4" && n != "tcp6" {
  25. // for testing with non-TCP
  26. return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
  27. }
  28. ta := addr.(*net.TCPAddr)
  29. return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
  30. }
  31. func (d peerAddr) Network() string {
  32. if d.IP.To4() != nil {
  33. return "tcp4"
  34. } else {
  35. return "tcp6"
  36. }
  37. }
  38. func (d peerAddr) String() string {
  39. return fmt.Sprintf("%v:%d", d.IP, d.Port)
  40. }
  41. func (d *peerAddr) RlpData() interface{} {
  42. return []interface{}{string(d.IP), d.Port, d.Pubkey}
  43. }
  44. // Peer represents a remote peer.
  45. type Peer struct {
  46. // Peers have all the log methods.
  47. // Use them to display messages related to the peer.
  48. *logger.Logger
  49. infolock sync.Mutex
  50. identity ClientIdentity
  51. caps []Cap
  52. listenAddr *peerAddr // what remote peer is listening on
  53. dialAddr *peerAddr // non-nil if dialing
  54. // The mutex protects the connection
  55. // so only one protocol can write at a time.
  56. writeMu sync.Mutex
  57. conn net.Conn
  58. bufconn *bufio.ReadWriter
  59. // These fields maintain the running protocols.
  60. protocols []Protocol
  61. runBaseProtocol bool // for testing
  62. runlock sync.RWMutex // protects running
  63. running map[string]*proto
  64. protoWG sync.WaitGroup
  65. protoErr chan error
  66. closed chan struct{}
  67. disc chan DiscReason
  68. activity event.TypeMux // for activity events
  69. slot int // index into Server peer list
  70. // These fields are kept so base protocol can access them.
  71. // TODO: this should be one or more interfaces
  72. ourID ClientIdentity // client id of the Server
  73. ourListenAddr *peerAddr // listen addr of Server, nil if not listening
  74. newPeerAddr chan<- *peerAddr // tell server about received peers
  75. otherPeers func() []*Peer // should return the list of all peers
  76. pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
  77. }
  78. // NewPeer returns a peer for testing purposes.
  79. func NewPeer(id ClientIdentity, caps []Cap) *Peer {
  80. conn, _ := net.Pipe()
  81. peer := newPeer(conn, nil, nil)
  82. peer.setHandshakeInfo(id, nil, caps)
  83. close(peer.closed)
  84. return peer
  85. }
  86. func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
  87. p := newPeer(conn, server.Protocols, dialAddr)
  88. p.ourID = server.Identity
  89. p.newPeerAddr = server.peerConnect
  90. p.otherPeers = server.Peers
  91. p.pubkeyHook = server.verifyPeer
  92. p.runBaseProtocol = true
  93. // laddr can be updated concurrently by NAT traversal.
  94. // newServerPeer must be called with the server lock held.
  95. if server.laddr != nil {
  96. p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
  97. }
  98. return p
  99. }
  100. func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
  101. p := &Peer{
  102. Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
  103. conn: conn,
  104. dialAddr: dialAddr,
  105. bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
  106. protocols: protocols,
  107. running: make(map[string]*proto),
  108. disc: make(chan DiscReason),
  109. protoErr: make(chan error),
  110. closed: make(chan struct{}),
  111. }
  112. return p
  113. }
  114. // Identity returns the client identity of the remote peer. The
  115. // identity can be nil if the peer has not yet completed the
  116. // handshake.
  117. func (p *Peer) Identity() ClientIdentity {
  118. p.infolock.Lock()
  119. defer p.infolock.Unlock()
  120. return p.identity
  121. }
  122. // Caps returns the capabilities (supported subprotocols) of the remote peer.
  123. func (p *Peer) Caps() []Cap {
  124. p.infolock.Lock()
  125. defer p.infolock.Unlock()
  126. return p.caps
  127. }
  128. func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
  129. p.infolock.Lock()
  130. p.identity = id
  131. p.listenAddr = laddr
  132. p.caps = caps
  133. p.infolock.Unlock()
  134. }
  135. // RemoteAddr returns the remote address of the network connection.
  136. func (p *Peer) RemoteAddr() net.Addr {
  137. return p.conn.RemoteAddr()
  138. }
  139. // LocalAddr returns the local address of the network connection.
  140. func (p *Peer) LocalAddr() net.Addr {
  141. return p.conn.LocalAddr()
  142. }
  143. // Disconnect terminates the peer connection with the given reason.
  144. // It returns immediately and does not wait until the connection is closed.
  145. func (p *Peer) Disconnect(reason DiscReason) {
  146. select {
  147. case p.disc <- reason:
  148. case <-p.closed:
  149. }
  150. }
  151. // String implements fmt.Stringer.
  152. func (p *Peer) String() string {
  153. kind := "inbound"
  154. p.infolock.Lock()
  155. if p.dialAddr != nil {
  156. kind = "outbound"
  157. }
  158. p.infolock.Unlock()
  159. return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
  160. }
  161. const (
  162. // maximum amount of time allowed for reading a message
  163. msgReadTimeout = 5 * time.Second
  164. // maximum amount of time allowed for writing a message
  165. msgWriteTimeout = 5 * time.Second
  166. // messages smaller than this many bytes will be read at
  167. // once before passing them to a protocol.
  168. wholePayloadSize = 64 * 1024
  169. )
  170. var (
  171. inactivityTimeout = 2 * time.Second
  172. disconnectGracePeriod = 2 * time.Second
  173. )
  174. func (p *Peer) loop() (reason DiscReason, err error) {
  175. defer p.activity.Stop()
  176. defer p.closeProtocols()
  177. defer close(p.closed)
  178. defer p.conn.Close()
  179. // read loop
  180. readMsg := make(chan Msg)
  181. readErr := make(chan error)
  182. readNext := make(chan bool, 1)
  183. protoDone := make(chan struct{}, 1)
  184. go p.readLoop(readMsg, readErr, readNext)
  185. readNext <- true
  186. if p.runBaseProtocol {
  187. p.startBaseProtocol()
  188. }
  189. loop:
  190. for {
  191. select {
  192. case msg := <-readMsg:
  193. // a new message has arrived.
  194. var wait bool
  195. if wait, err = p.dispatch(msg, protoDone); err != nil {
  196. p.Errorf("msg dispatch error: %v\n", err)
  197. reason = discReasonForError(err)
  198. break loop
  199. }
  200. if !wait {
  201. // Msg has already been read completely, continue with next message.
  202. readNext <- true
  203. }
  204. p.activity.Post(time.Now())
  205. case <-protoDone:
  206. // protocol has consumed the message payload,
  207. // we can continue reading from the socket.
  208. readNext <- true
  209. case err := <-readErr:
  210. // read failed. there is no need to run the
  211. // polite disconnect sequence because the connection
  212. // is probably dead anyway.
  213. // TODO: handle write errors as well
  214. return DiscNetworkError, err
  215. case err = <-p.protoErr:
  216. reason = discReasonForError(err)
  217. break loop
  218. case reason = <-p.disc:
  219. break loop
  220. }
  221. }
  222. // wait for read loop to return.
  223. close(readNext)
  224. <-readErr
  225. // tell the remote end to disconnect
  226. done := make(chan struct{})
  227. go func() {
  228. p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
  229. p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
  230. io.Copy(ioutil.Discard, p.conn)
  231. close(done)
  232. }()
  233. select {
  234. case <-done:
  235. case <-time.After(disconnectGracePeriod):
  236. }
  237. return reason, err
  238. }
  239. func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
  240. for _ = range unblock {
  241. p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
  242. if msg, err := readMsg(p.bufconn); err != nil {
  243. errc <- err
  244. } else {
  245. msgc <- msg
  246. }
  247. }
  248. close(errc)
  249. }
  250. func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
  251. proto, err := p.getProto(msg.Code)
  252. if err != nil {
  253. return false, err
  254. }
  255. if msg.Size <= wholePayloadSize {
  256. // optimization: msg is small enough, read all
  257. // of it and move on to the next message
  258. buf, err := ioutil.ReadAll(msg.Payload)
  259. if err != nil {
  260. return false, err
  261. }
  262. msg.Payload = bytes.NewReader(buf)
  263. proto.in <- msg
  264. } else {
  265. wait = true
  266. pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
  267. msg.Payload = pr
  268. proto.in <- msg
  269. }
  270. return wait, nil
  271. }
  272. func (p *Peer) startBaseProtocol() {
  273. p.runlock.Lock()
  274. defer p.runlock.Unlock()
  275. p.running[""] = p.startProto(0, Protocol{
  276. Length: baseProtocolLength,
  277. Run: runBaseProtocol,
  278. })
  279. }
  280. // startProtocols starts matching named subprotocols.
  281. func (p *Peer) startSubprotocols(caps []Cap) {
  282. sort.Sort(capsByName(caps))
  283. p.runlock.Lock()
  284. defer p.runlock.Unlock()
  285. offset := baseProtocolLength
  286. outer:
  287. for _, cap := range caps {
  288. for _, proto := range p.protocols {
  289. if proto.Name == cap.Name &&
  290. proto.Version == cap.Version &&
  291. p.running[cap.Name] == nil {
  292. p.running[cap.Name] = p.startProto(offset, proto)
  293. offset += proto.Length
  294. continue outer
  295. }
  296. }
  297. }
  298. }
  299. func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
  300. rw := &proto{
  301. in: make(chan Msg),
  302. offset: offset,
  303. maxcode: impl.Length,
  304. peer: p,
  305. }
  306. p.protoWG.Add(1)
  307. go func() {
  308. err := impl.Run(p, rw)
  309. if err == nil {
  310. p.Infof("protocol %q returned", impl.Name)
  311. err = newPeerError(errMisc, "protocol returned")
  312. } else {
  313. p.Errorf("protocol %q error: %v\n", impl.Name, err)
  314. }
  315. select {
  316. case p.protoErr <- err:
  317. case <-p.closed:
  318. }
  319. p.protoWG.Done()
  320. }()
  321. return rw
  322. }
  323. // getProto finds the protocol responsible for handling
  324. // the given message code.
  325. func (p *Peer) getProto(code uint64) (*proto, error) {
  326. p.runlock.RLock()
  327. defer p.runlock.RUnlock()
  328. for _, proto := range p.running {
  329. if code >= proto.offset && code < proto.offset+proto.maxcode {
  330. return proto, nil
  331. }
  332. }
  333. return nil, newPeerError(errInvalidMsgCode, "%d", code)
  334. }
  335. func (p *Peer) closeProtocols() {
  336. p.runlock.RLock()
  337. for _, p := range p.running {
  338. close(p.in)
  339. }
  340. p.runlock.RUnlock()
  341. p.protoWG.Wait()
  342. }
  343. // writeProtoMsg sends the given message on behalf of the given named protocol.
  344. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
  345. p.runlock.RLock()
  346. proto, ok := p.running[protoName]
  347. p.runlock.RUnlock()
  348. if !ok {
  349. return fmt.Errorf("protocol %s not handled by peer", protoName)
  350. }
  351. if msg.Code >= proto.maxcode {
  352. return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
  353. }
  354. msg.Code += proto.offset
  355. return p.writeMsg(msg, msgWriteTimeout)
  356. }
  357. // writeMsg writes a message to the connection.
  358. func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
  359. p.writeMu.Lock()
  360. defer p.writeMu.Unlock()
  361. p.conn.SetWriteDeadline(time.Now().Add(timeout))
  362. if err := writeMsg(p.bufconn, msg); err != nil {
  363. return newPeerError(errWrite, "%v", err)
  364. }
  365. return p.bufconn.Flush()
  366. }
  367. type proto struct {
  368. name string
  369. in chan Msg
  370. maxcode, offset uint64
  371. peer *Peer
  372. }
  373. func (rw *proto) WriteMsg(msg Msg) error {
  374. if msg.Code >= rw.maxcode {
  375. return newPeerError(errInvalidMsgCode, "not handled")
  376. }
  377. msg.Code += rw.offset
  378. return rw.peer.writeMsg(msg, msgWriteTimeout)
  379. }
  380. func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
  381. return rw.WriteMsg(NewMsg(code, data...))
  382. }
  383. func (rw *proto) ReadMsg() (Msg, error) {
  384. msg, ok := <-rw.in
  385. if !ok {
  386. return msg, io.EOF
  387. }
  388. msg.Code -= rw.offset
  389. return msg, nil
  390. }
  391. // eofSignal wraps a reader with eof signaling. the eof channel is
  392. // closed when the wrapped reader returns an error or when count bytes
  393. // have been read.
  394. //
  395. type eofSignal struct {
  396. wrapped io.Reader
  397. count int64
  398. eof chan<- struct{}
  399. }
  400. // note: when using eofSignal to detect whether a message payload
  401. // has been read, Read might not be called for zero sized messages.
  402. func (r *eofSignal) Read(buf []byte) (int, error) {
  403. n, err := r.wrapped.Read(buf)
  404. r.count -= int64(n)
  405. if (err != nil || r.count <= 0) && r.eof != nil {
  406. r.eof <- struct{}{} // tell Peer that msg has been consumed
  407. r.eof = nil
  408. }
  409. return n, err
  410. }