瀏覽代碼

p2p: fix issues found during review

Felix Lange 11 年之前
父節點
當前提交
7149191dd9
共有 4 個文件被更改,包括 96 次插入53 次删除
  1. 1 1
      p2p/message.go
  2. 7 7
      p2p/messenger.go
  3. 87 41
      p2p/messenger_test.go
  4. 1 4
      p2p/protocol.go

+ 1 - 1
p2p/message.go

@@ -98,7 +98,7 @@ type byteReader interface {
 	io.ByteReader
 }
 
-// readMsg reads a message header.
+// readMsg reads a message header from r.
 func readMsg(r byteReader) (msg Msg, err error) {
 	// read magic and payload size
 	start := make([]byte, 8)

+ 7 - 7
p2p/messenger.go

@@ -11,7 +11,7 @@ import (
 	"time"
 )
 
-type Handlers map[string]func() Protocol
+type Handlers map[string]Protocol
 
 type proto struct {
 	in              chan Msg
@@ -23,6 +23,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
 	if msg.Code >= rw.maxcode {
 		return NewPeerError(InvalidMsgCode, "not handled")
 	}
+	msg.Code += rw.offset
 	return rw.messenger.writeMsg(msg)
 }
 
@@ -31,12 +32,13 @@ func (rw *proto) ReadMsg() (Msg, error) {
 	if !ok {
 		return msg, io.EOF
 	}
+	msg.Code -= rw.offset
 	return msg, nil
 }
 
-// eofSignal is used to 'lend' the network connection
-// to a protocol. when the protocol's read loop has read the
-// whole payload, the done channel is closed.
+// eofSignal wraps a reader with eof signaling.
+// the eof channel is closed when the wrapped reader
+// reaches EOF.
 type eofSignal struct {
 	wrapped io.Reader
 	eof     chan struct{}
@@ -119,7 +121,6 @@ func (m *messenger) readLoop() {
 			m.err <- err
 			return
 		}
-		msg.Code -= proto.offset
 		if msg.Size <= wholePayloadSize {
 			// optimization: msg is small enough, read all
 			// of it and move on to the next message
@@ -185,11 +186,10 @@ func (m *messenger) setRemoteProtocols(protocols []string) {
 	defer m.protocolLock.Unlock()
 	offset := baseProtocolOffset
 	for _, name := range protocols {
-		protocolFunc, ok := m.handlers[name]
+		inst, ok := m.handlers[name]
 		if !ok {
 			continue // not handled
 		}
-		inst := protocolFunc()
 		m.protocols[name] = m.startProto(offset, name, inst)
 		offset += inst.Offset()
 	}

+ 87 - 41
p2p/messenger_test.go

@@ -11,14 +11,14 @@ import (
 	"testing"
 	"time"
 
-	"github.com/ethereum/go-ethereum/ethutil"
+	logpkg "github.com/ethereum/go-ethereum/logger"
 )
 
 func init() {
-	ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel))
+	logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
 }
 
-func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
+func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
 	conn1, conn2 := net.Pipe()
 	id := NewSimpleClientIdentity("test", "0", "0", "public key")
 	server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
@@ -33,7 +33,7 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error {
 		return fmt.Errorf("read error: %v", err)
 	}
 	if msg.Code != handshakeMsg {
-		return fmt.Errorf("first message should be handshake, got %x", msg.Code)
+		return fmt.Errorf("first message should be handshake, got %d", msg.Code)
 	}
 	if err := msg.Discard(); err != nil {
 		return err
@@ -44,56 +44,102 @@ func performTestHandshake(r *bufio.Reader, w io.Writer) error {
 	return writeMsg(w, msg)
 }
 
-type testMsg struct {
-	code MsgCode
-	data *ethutil.Value
+type testProtocol struct {
+	offset MsgCode
+	f      func(MsgReadWriter)
 }
 
-type testProto struct {
-	recv chan testMsg
+func (p *testProtocol) Offset() MsgCode {
+	return p.offset
 }
 
-func (*testProto) Offset() MsgCode { return 5 }
-
-func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error {
-	return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error {
-		logger.Debugf("testprotocol got msg: %d\n", code)
-		tp.recv <- testMsg{code, data}
-		return nil
-	})
+func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
+	p.f(rw)
+	return nil
 }
 
 func TestRead(t *testing.T) {
-	testProtocol := &testProto{make(chan testMsg)}
-	handlers := Handlers{"a": func() Protocol { return testProtocol }}
-	net, peer, mess := setupMessenger(handlers)
-	bufr := bufio.NewReader(net)
+	done := make(chan struct{})
+	handlers := Handlers{
+		"a": &testProtocol{5, func(rw MsgReadWriter) {
+			msg, err := rw.ReadMsg()
+			if err != nil {
+				t.Errorf("read error: %v", err)
+			}
+			if msg.Code != 2 {
+				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
+			}
+			data, err := msg.Data()
+			if err != nil {
+				t.Errorf("data decoding error: %v", err)
+			}
+			expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
+			if !reflect.DeepEqual(data.Slice(), expdata) {
+				t.Errorf("incorrect msg data %#v", data.Slice())
+			}
+			close(done)
+		}},
+	}
+
+	net, peer, m := testMessenger(handlers)
 	defer peer.Stop()
+	bufr := bufio.NewReader(net)
 	if err := performTestHandshake(bufr, net); err != nil {
 		t.Fatalf("handshake failed: %v", err)
 	}
+	m.setRemoteProtocols([]string{"a"})
 
-	mess.setRemoteProtocols([]string{"a"})
-	writeMsg(net, NewMsg(17, uint32(1), "000"))
+	writeMsg(net, NewMsg(18, 1, "000"))
 	select {
-	case msg := <-testProtocol.recv:
-		if msg.code != 1 {
-			t.Errorf("incorrect msg code %d relayed to protocol", msg.code)
-		}
-		expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
-		if !reflect.DeepEqual(msg.data.Slice(), expdata) {
-			t.Errorf("incorrect msg data %#v", msg.data.Slice())
-		}
+	case <-done:
 	case <-time.After(2 * time.Second):
 		t.Errorf("receive timeout")
 	}
 }
 
-func TestWriteProtoMsg(t *testing.T) {
-	handlers := make(Handlers)
-	testProtocol := &testProto{recv: make(chan testMsg, 1)}
-	handlers["a"] = func() Protocol { return testProtocol }
-	net, peer, mess := setupMessenger(handlers)
+func TestWriteFromProto(t *testing.T) {
+	handlers := Handlers{
+		"a": &testProtocol{2, func(rw MsgReadWriter) {
+			if err := rw.WriteMsg(NewMsg(2)); err == nil {
+				t.Error("expected error for out-of-range msg code, got nil")
+			}
+			if err := rw.WriteMsg(NewMsg(1)); err != nil {
+				t.Errorf("write error: %v", err)
+			}
+		}},
+	}
+	net, peer, mess := testMessenger(handlers)
+	defer peer.Stop()
+	bufr := bufio.NewReader(net)
+	if err := performTestHandshake(bufr, net); err != nil {
+		t.Fatalf("handshake failed: %v", err)
+	}
+	mess.setRemoteProtocols([]string{"a"})
+
+	msg, err := readMsg(bufr)
+	if err != nil {
+		t.Errorf("read error: %v")
+	}
+	if msg.Code != 17 {
+		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
+	}
+}
+
+var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
+	for {
+		msg, err := rw.ReadMsg()
+		if err != nil {
+			return
+		}
+		if err = msg.Discard(); err != nil {
+			return
+		}
+	}
+}}
+
+func TestMessengerWriteProtoMsg(t *testing.T) {
+	handlers := Handlers{"a": discardProto}
+	net, peer, mess := testMessenger(handlers)
 	defer peer.Stop()
 	bufr := bufio.NewReader(net)
 	if err := performTestHandshake(bufr, net); err != nil {
@@ -120,13 +166,13 @@ func TestWriteProtoMsg(t *testing.T) {
 			read <- msg
 		}
 	}()
-	if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil {
+	if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
 		t.Errorf("expect no error for known protocol: %v", err)
 	}
 	select {
 	case msg := <-read:
-		if msg.Code != 19 {
-			t.Errorf("wrong code, got %d, expected %d", msg.Code, 19)
+		if msg.Code != 16 {
+			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
 		}
 		msg.Discard()
 	case err := <-readerr:
@@ -135,7 +181,7 @@ func TestWriteProtoMsg(t *testing.T) {
 }
 
 func TestPulse(t *testing.T) {
-	net, peer, _ := setupMessenger(nil)
+	net, peer, _ := testMessenger(nil)
 	defer peer.Stop()
 	bufr := bufio.NewReader(net)
 	if err := performTestHandshake(bufr, net); err != nil {
@@ -149,7 +195,7 @@ func TestPulse(t *testing.T) {
 	}
 	after := time.Now()
 	if msg.Code != pingMsg {
-		t.Errorf("expected ping message, got %x", msg.Code)
+		t.Errorf("expected ping message, got %d", msg.Code)
 	}
 	if d := after.Sub(before); d < pingTimeout {
 		t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)

+ 1 - 4
p2p/protocol.go

@@ -143,9 +143,6 @@ func (d DiscReason) String() string {
 	return discReasonToString[d]
 }
 
-func (bp *baseProtocol) Ping() {
-}
-
 func (bp *baseProtocol) Offset() MsgCode {
 	return baseProtocolOffset
 }
@@ -287,7 +284,7 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
 
 	// self connect detection
 	if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
-		return NewPeerError(PubkeyForbidden, "not allowed to connect to bp")
+		return NewPeerError(PubkeyForbidden, "not allowed to connect to self")
 	}
 
 	// register pubkey on server. this also sets the pubkey on the peer (need lock)