messenger.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. package p2p
  2. import (
  3. "bufio"
  4. "bytes"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net"
  9. "sync"
  10. "time"
  11. )
  12. type Handlers map[string]Protocol
  13. type proto struct {
  14. in chan Msg
  15. maxcode, offset MsgCode
  16. messenger *messenger
  17. }
  18. func (rw *proto) WriteMsg(msg Msg) error {
  19. if msg.Code >= rw.maxcode {
  20. return NewPeerError(InvalidMsgCode, "not handled")
  21. }
  22. msg.Code += rw.offset
  23. return rw.messenger.writeMsg(msg)
  24. }
  25. func (rw *proto) ReadMsg() (Msg, error) {
  26. msg, ok := <-rw.in
  27. if !ok {
  28. return msg, io.EOF
  29. }
  30. msg.Code -= rw.offset
  31. return msg, nil
  32. }
  33. // eofSignal wraps a reader with eof signaling.
  34. // the eof channel is closed when the wrapped reader
  35. // reaches EOF.
  36. type eofSignal struct {
  37. wrapped io.Reader
  38. eof chan struct{}
  39. }
  40. func (r *eofSignal) Read(buf []byte) (int, error) {
  41. n, err := r.wrapped.Read(buf)
  42. if err != nil {
  43. close(r.eof) // tell messenger that msg has been consumed
  44. }
  45. return n, err
  46. }
  47. // messenger represents a message-oriented peer connection.
  48. // It keeps track of the set of protocols understood
  49. // by the remote peer.
  50. type messenger struct {
  51. peer *Peer
  52. handlers Handlers
  53. // the mutex protects the connection
  54. // so only one protocol can write at a time.
  55. writeMu sync.Mutex
  56. conn net.Conn
  57. bufconn *bufio.ReadWriter
  58. protocolLock sync.RWMutex
  59. protocols map[string]*proto
  60. offsets map[MsgCode]*proto
  61. protoWG sync.WaitGroup
  62. err chan error
  63. pulse chan bool
  64. }
  65. func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
  66. return &messenger{
  67. conn: conn,
  68. bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
  69. peer: peer,
  70. handlers: handlers,
  71. protocols: make(map[string]*proto),
  72. err: errchan,
  73. pulse: make(chan bool, 1),
  74. }
  75. }
  76. func (m *messenger) Start() {
  77. m.protocols[""] = m.startProto(0, "", &baseProtocol{})
  78. go m.readLoop()
  79. }
  80. func (m *messenger) Stop() {
  81. m.conn.Close()
  82. m.protoWG.Wait()
  83. }
  84. const (
  85. // maximum amount of time allowed for reading a message
  86. msgReadTimeout = 5 * time.Second
  87. // messages smaller than this many bytes will be read at
  88. // once before passing them to a protocol.
  89. wholePayloadSize = 64 * 1024
  90. )
  91. func (m *messenger) readLoop() {
  92. defer m.closeProtocols()
  93. for {
  94. m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
  95. msg, err := readMsg(m.bufconn)
  96. if err != nil {
  97. m.err <- err
  98. return
  99. }
  100. // send ping to heartbeat channel signalling time of last message
  101. m.pulse <- true
  102. proto, err := m.getProto(msg.Code)
  103. if err != nil {
  104. m.err <- err
  105. return
  106. }
  107. if msg.Size <= wholePayloadSize {
  108. // optimization: msg is small enough, read all
  109. // of it and move on to the next message
  110. buf, err := ioutil.ReadAll(msg.Payload)
  111. if err != nil {
  112. m.err <- err
  113. return
  114. }
  115. msg.Payload = bytes.NewReader(buf)
  116. proto.in <- msg
  117. } else {
  118. pr := &eofSignal{msg.Payload, make(chan struct{})}
  119. msg.Payload = pr
  120. proto.in <- msg
  121. <-pr.eof
  122. }
  123. }
  124. }
  125. func (m *messenger) closeProtocols() {
  126. m.protocolLock.RLock()
  127. for _, p := range m.protocols {
  128. close(p.in)
  129. }
  130. m.protocolLock.RUnlock()
  131. }
  132. func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
  133. proto := &proto{
  134. in: make(chan Msg),
  135. offset: offset,
  136. maxcode: impl.Offset(),
  137. messenger: m,
  138. }
  139. m.protoWG.Add(1)
  140. go func() {
  141. if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
  142. logger.Errorf("protocol %q error: %v\n", name, err)
  143. m.err <- err
  144. }
  145. m.protoWG.Done()
  146. }()
  147. return proto
  148. }
  149. // getProto finds the protocol responsible for handling
  150. // the given message code.
  151. func (m *messenger) getProto(code MsgCode) (*proto, error) {
  152. m.protocolLock.RLock()
  153. defer m.protocolLock.RUnlock()
  154. for _, proto := range m.protocols {
  155. if code >= proto.offset && code < proto.offset+proto.maxcode {
  156. return proto, nil
  157. }
  158. }
  159. return nil, NewPeerError(InvalidMsgCode, "%d", code)
  160. }
  161. // setProtocols starts all subprotocols shared with the
  162. // remote peer. the protocols must be sorted alphabetically.
  163. func (m *messenger) setRemoteProtocols(protocols []string) {
  164. m.protocolLock.Lock()
  165. defer m.protocolLock.Unlock()
  166. offset := baseProtocolOffset
  167. for _, name := range protocols {
  168. inst, ok := m.handlers[name]
  169. if !ok {
  170. continue // not handled
  171. }
  172. m.protocols[name] = m.startProto(offset, name, inst)
  173. offset += inst.Offset()
  174. }
  175. }
  176. // writeProtoMsg sends the given message on behalf of the given named protocol.
  177. func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
  178. m.protocolLock.RLock()
  179. proto, ok := m.protocols[protoName]
  180. m.protocolLock.RUnlock()
  181. if !ok {
  182. return fmt.Errorf("protocol %s not handled by peer", protoName)
  183. }
  184. if msg.Code >= proto.maxcode {
  185. return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
  186. }
  187. msg.Code += proto.offset
  188. return m.writeMsg(msg)
  189. }
  190. // writeMsg writes a message to the connection.
  191. func (m *messenger) writeMsg(msg Msg) error {
  192. m.writeMu.Lock()
  193. defer m.writeMu.Unlock()
  194. if err := writeMsg(m.bufconn, msg); err != nil {
  195. return err
  196. }
  197. return m.bufconn.Flush()
  198. }