| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- package p2p
- import (
- "bufio"
- "bytes"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "sync"
- "time"
- )
- type Handlers map[string]Protocol
- type proto struct {
- in chan Msg
- maxcode, offset MsgCode
- messenger *messenger
- }
- func (rw *proto) WriteMsg(msg Msg) error {
- if msg.Code >= rw.maxcode {
- return NewPeerError(InvalidMsgCode, "not handled")
- }
- msg.Code += rw.offset
- return rw.messenger.writeMsg(msg)
- }
- func (rw *proto) ReadMsg() (Msg, error) {
- msg, ok := <-rw.in
- if !ok {
- return msg, io.EOF
- }
- msg.Code -= rw.offset
- return msg, nil
- }
- // eofSignal wraps a reader with eof signaling.
- // the eof channel is closed when the wrapped reader
- // reaches EOF.
- type eofSignal struct {
- wrapped io.Reader
- eof chan struct{}
- }
- func (r *eofSignal) Read(buf []byte) (int, error) {
- n, err := r.wrapped.Read(buf)
- if err != nil {
- close(r.eof) // tell messenger that msg has been consumed
- }
- return n, err
- }
- // messenger represents a message-oriented peer connection.
- // It keeps track of the set of protocols understood
- // by the remote peer.
- type messenger struct {
- peer *Peer
- handlers Handlers
- // the mutex protects the connection
- // so only one protocol can write at a time.
- writeMu sync.Mutex
- conn net.Conn
- bufconn *bufio.ReadWriter
- protocolLock sync.RWMutex
- protocols map[string]*proto
- offsets map[MsgCode]*proto
- protoWG sync.WaitGroup
- err chan error
- pulse chan bool
- }
- func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
- return &messenger{
- conn: conn,
- bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
- peer: peer,
- handlers: handlers,
- protocols: make(map[string]*proto),
- err: errchan,
- pulse: make(chan bool, 1),
- }
- }
- func (m *messenger) Start() {
- m.protocols[""] = m.startProto(0, "", &baseProtocol{})
- go m.readLoop()
- }
- func (m *messenger) Stop() {
- m.conn.Close()
- m.protoWG.Wait()
- }
- const (
- // maximum amount of time allowed for reading a message
- msgReadTimeout = 5 * time.Second
- // messages smaller than this many bytes will be read at
- // once before passing them to a protocol.
- wholePayloadSize = 64 * 1024
- )
- func (m *messenger) readLoop() {
- defer m.closeProtocols()
- for {
- m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
- msg, err := readMsg(m.bufconn)
- if err != nil {
- m.err <- err
- return
- }
- // send ping to heartbeat channel signalling time of last message
- m.pulse <- true
- proto, err := m.getProto(msg.Code)
- if err != nil {
- m.err <- err
- return
- }
- if msg.Size <= wholePayloadSize {
- // optimization: msg is small enough, read all
- // of it and move on to the next message
- buf, err := ioutil.ReadAll(msg.Payload)
- if err != nil {
- m.err <- err
- return
- }
- msg.Payload = bytes.NewReader(buf)
- proto.in <- msg
- } else {
- pr := &eofSignal{msg.Payload, make(chan struct{})}
- msg.Payload = pr
- proto.in <- msg
- <-pr.eof
- }
- }
- }
- func (m *messenger) closeProtocols() {
- m.protocolLock.RLock()
- for _, p := range m.protocols {
- close(p.in)
- }
- m.protocolLock.RUnlock()
- }
- func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
- proto := &proto{
- in: make(chan Msg),
- offset: offset,
- maxcode: impl.Offset(),
- messenger: m,
- }
- m.protoWG.Add(1)
- go func() {
- if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
- logger.Errorf("protocol %q error: %v\n", name, err)
- m.err <- err
- }
- m.protoWG.Done()
- }()
- return proto
- }
- // getProto finds the protocol responsible for handling
- // the given message code.
- func (m *messenger) getProto(code MsgCode) (*proto, error) {
- m.protocolLock.RLock()
- defer m.protocolLock.RUnlock()
- for _, proto := range m.protocols {
- if code >= proto.offset && code < proto.offset+proto.maxcode {
- return proto, nil
- }
- }
- return nil, NewPeerError(InvalidMsgCode, "%d", code)
- }
- // setProtocols starts all subprotocols shared with the
- // remote peer. the protocols must be sorted alphabetically.
- func (m *messenger) setRemoteProtocols(protocols []string) {
- m.protocolLock.Lock()
- defer m.protocolLock.Unlock()
- offset := baseProtocolOffset
- for _, name := range protocols {
- inst, ok := m.handlers[name]
- if !ok {
- continue // not handled
- }
- m.protocols[name] = m.startProto(offset, name, inst)
- offset += inst.Offset()
- }
- }
- // writeProtoMsg sends the given message on behalf of the given named protocol.
- func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
- m.protocolLock.RLock()
- proto, ok := m.protocols[protoName]
- m.protocolLock.RUnlock()
- if !ok {
- return fmt.Errorf("protocol %s not handled by peer", protoName)
- }
- if msg.Code >= proto.maxcode {
- return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
- }
- msg.Code += proto.offset
- return m.writeMsg(msg)
- }
- // writeMsg writes a message to the connection.
- func (m *messenger) writeMsg(msg Msg) error {
- m.writeMu.Lock()
- defer m.writeMu.Unlock()
- if err := writeMsg(m.bufconn, msg); err != nil {
- return err
- }
- return m.bufconn.Flush()
- }
|