Browse Source

p2p: use package rlp for baseProtocol

Felix Lange 11 years ago
parent
commit
6049fcd52a
4 changed files with 71 additions and 58 deletions
  1. 13 5
      p2p/message.go
  2. 1 1
      p2p/message_test.go
  3. 1 1
      p2p/peer_test.go
  4. 56 51
      p2p/protocol.go

+ 13 - 5
p2p/message.go

@@ -41,14 +41,22 @@ func encodePayload(params ...interface{}) []byte {
 	return buf.Bytes()
 }
 
-// Data returns the decoded RLP payload items in a message.
-func (msg Msg) Data() (*ethutil.Value, error) {
-	s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
+// Value returns the decoded RLP payload items in a message.
+func (msg Msg) Value() (*ethutil.Value, error) {
 	var v []interface{}
-	err := s.Decode(&v)
+	err := msg.Decode(&v)
 	return ethutil.NewValue(v), err
 }
 
+// Decode parse the RLP content of a message into
+// the given value, which must be a pointer.
+//
+// For the decoding rules, please see package rlp.
+func (msg Msg) Decode(val interface{}) error {
+	s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
+	return s.Decode(val)
+}
+
 // Discard reads any remaining payload data into a black hole.
 func (msg Msg) Discard() error {
 	_, err := io.Copy(ioutil.Discard, msg.Payload)
@@ -91,7 +99,7 @@ func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Valu
 		if msg.Size > maxsize {
 			return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
 		}
-		value, err := msg.Data()
+		value, err := msg.Value()
 		if err != nil {
 			return err
 		}

+ 1 - 1
p2p/message_test.go

@@ -42,7 +42,7 @@ func TestEncodeDecodeMsg(t *testing.T) {
 	if decmsg.Size != 5 {
 		t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
 	}
-	data, err := decmsg.Data()
+	data, err := decmsg.Value()
 	if err != nil {
 		t.Fatalf("first payload item decode error: %v", err)
 	}

+ 1 - 1
p2p/peer_test.go

@@ -53,7 +53,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
 			if msg.Code != 2 {
 				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
 			}
-			data, err := msg.Data()
+			data, err := msg.Value()
 			if err != nil {
 				t.Errorf("data decoding error: %v", err)
 			}

+ 56 - 51
p2p/protocol.go

@@ -2,7 +2,6 @@ package p2p
 
 import (
 	"bytes"
-	"net"
 	"time"
 
 	"github.com/ethereum/go-ethereum/ethutil"
@@ -90,30 +89,18 @@ type baseProtocol struct {
 
 func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
 	bp := &baseProtocol{rw, peer}
-
-	// do handshake
-	if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
-		return err
-	}
-	msg, err := rw.ReadMsg()
-	if err != nil {
+	if err := bp.doHandshake(rw); err != nil {
 		return err
 	}
-	if msg.Code != handshakeMsg {
-		return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
-	}
-	data, err := msg.Data()
-	if err != nil {
-		return newPeerError(errInvalidMsg, "%v", err)
-	}
-	if err := bp.handleHandshake(data); err != nil {
-		return err
-	}
-
 	// run main loop
 	quit := make(chan error, 1)
 	go func() {
-		quit <- MsgLoop(rw, baseProtocolMaxMsgSize, bp.handle)
+		for {
+			if err := bp.handle(rw); err != nil {
+				quit <- err
+				break
+			}
+		}
 	}()
 	return bp.loop(quit)
 }
@@ -151,13 +138,27 @@ func (bp *baseProtocol) loop(quit <-chan error) error {
 	return err
 }
 
-func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
-	switch code {
+func (bp *baseProtocol) handle(rw MsgReadWriter) error {
+	msg, err := rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	if msg.Size > baseProtocolMaxMsgSize {
+		return newPeerError(errMisc, "message too big")
+	}
+	// make sure that the payload has been fully consumed
+	defer msg.Discard()
+
+	switch msg.Code {
 	case handshakeMsg:
 		return newPeerError(errProtocolBreach, "extra handshake received")
 
 	case discMsg:
-		bp.peer.Disconnect(DiscReason(data.Get(0).Uint()))
+		var reason DiscReason
+		if err := msg.Decode(&reason); err != nil {
+			return err
+		}
+		bp.peer.Disconnect(reason)
 		return nil
 
 	case pingMsg:
@@ -178,35 +179,45 @@ func (bp *baseProtocol) handle(code uint64, data *ethutil.Value) error {
 		}
 
 	case peersMsg:
-		bp.handlePeers(data)
+		var peers []*peerAddr
+		if err := msg.Decode(&peers); err != nil {
+			return err
+		}
+		for _, addr := range peers {
+			bp.peer.Debugf("received peer suggestion: %v", addr)
+			bp.peer.newPeerAddr <- addr
+		}
 
 	default:
-		return newPeerError(errInvalidMsgCode, "unknown message code %v", code)
+		return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
 	}
 	return nil
 }
 
-func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
-	it := data.NewIterator()
-	for it.Next() {
-		addr := &peerAddr{
-			IP:     net.IP(it.Value().Get(0).Bytes()),
-			Port:   it.Value().Get(1).Uint(),
-			Pubkey: it.Value().Get(2).Bytes(),
-		}
-		bp.peer.Debugf("received peer suggestion: %v", addr)
-		bp.peer.newPeerAddr <- addr
+func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
+	// send our handshake
+	if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
+		return err
+	}
+
+	// read and handle remote handshake
+	msg, err := rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	if msg.Code != handshakeMsg {
+		return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
+	}
+	if msg.Size > baseProtocolMaxMsgSize {
+		return newPeerError(errMisc, "message too big")
 	}
-}
 
-func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
-	hs := handshake{
-		Version:    c.Get(0).Uint(),
-		ID:         c.Get(1).Str(),
-		Caps:       nil, // decoded below
-		ListenPort: c.Get(3).Uint(),
-		NodeID:     c.Get(4).Bytes(),
+	var hs handshake
+	if err := msg.Decode(&hs); err != nil {
+		return err
 	}
+
+	// validate handshake info
 	if hs.Version != baseProtocolVersion {
 		return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
 			baseProtocolVersion, hs.Version)
@@ -228,14 +239,8 @@ func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
 	if err := bp.peer.pubkeyHook(pa); err != nil {
 		return newPeerError(errPubkeyForbidden, "%v", err)
 	}
-	capsIt := c.Get(2).NewIterator()
-	for capsIt.Next() {
-		cap := capsIt.Value()
-		name := cap.Get(0).Str()
-		if name != "" {
-			hs.Caps = append(hs.Caps, Cap{Name: name, Version: uint(cap.Get(1).Uint())})
-		}
-	}
+
+	// TODO: remove Caps with empty name
 
 	var addr *peerAddr
 	if hs.ListenPort != 0 {