Преглед на файлове

p2p: disable encryption handshake

The diff is a bit bigger than expected because the protocol handshake
logic has moved out of Peer. This is necessary because the protocol
handshake will have custom framing in the final protocol.
Felix Lange преди 10 години
родител
ревизия
73f94f3755
променени са 7 файла, в които са добавени 273 реда и са изтрити 313 реда
  1. 104 33
      p2p/handshake.go
  2. 60 3
      p2p/handshake_test.go
  3. 1 1
      p2p/message.go
  4. 59 170
      p2p/peer.go
  5. 20 85
      p2p/peer_test.go
  6. 22 16
      p2p/server.go
  7. 7 5
      p2p/server_test.go

+ 104 - 33
p2p/crypto.go → p2p/handshake.go

@@ -1,21 +1,20 @@
 package p2p
 
 import (
-	// "binary"
 	"crypto/ecdsa"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"io"
+	"net"
 
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/crypto/ecies"
 	"github.com/ethereum/go-ethereum/crypto/secp256k1"
-	ethlogger "github.com/ethereum/go-ethereum/logger"
 	"github.com/ethereum/go-ethereum/p2p/discover"
+	"github.com/ethereum/go-ethereum/rlp"
 )
 
-var clogger = ethlogger.NewLogger("CRYPTOID")
-
 const (
 	sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
 	sigLen = 65 // elliptic S256
@@ -30,26 +29,76 @@ const (
 	rHSLen     = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
 )
 
-type hexkey []byte
+type conn struct {
+	*frameRW
+	*protoHandshake
+}
 
-func (self hexkey) String() string {
-	return fmt.Sprintf("(%d) %x", len(self), []byte(self))
+func newConn(fd net.Conn, hs *protoHandshake) *conn {
+	return &conn{newFrameRW(fd, msgWriteTimeout), hs}
 }
 
-func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
-	remoteID discover.NodeID,
-	sessionToken []byte,
-	err error,
-) {
+// encHandshake represents information about the remote end
+// of a connection that is negotiated during the encryption handshake.
+type encHandshake struct {
+	ID         discover.NodeID
+	IngressMAC []byte
+	EgressMAC  []byte
+	Token      []byte
+}
+
+// protoHandshake is the RLP structure of the protocol handshake.
+type protoHandshake struct {
+	Version    uint64
+	Name       string
+	Caps       []Cap
+	ListenPort uint64
+	ID         discover.NodeID
+}
+
+// setupConn starts a protocol session on the given connection.
+// It runs the encryption handshake and the protocol handshake.
+// If dial is non-nil, the connection the local node is the initiator.
+func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
 	if dial == nil {
-		var remotePubkey []byte
-		sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
-		copy(remoteID[:], remotePubkey)
+		return setupInboundConn(fd, prv, our)
 	} else {
-		remoteID = dial.ID
-		sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
+		return setupOutboundConn(fd, prv, our, dial)
+	}
+}
+
+func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
+	// var remotePubkey []byte
+	// sessionToken, remotePubkey, err = inboundEncHandshake(fd, prv, nil)
+	// copy(remoteID[:], remotePubkey)
+
+	rw := newFrameRW(fd, msgWriteTimeout)
+	rhs, err := readProtocolHandshake(rw, our)
+	if err != nil {
+		return nil, err
+	}
+	if err := writeProtocolHandshake(rw, our); err != nil {
+		return nil, fmt.Errorf("protocol write error: %v", err)
+	}
+	return &conn{rw, rhs}, nil
+}
+
+func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+	// remoteID = dial.ID
+	// sessionToken, err = outboundEncHandshake(fd, prv, remoteID[:], nil)
+
+	rw := newFrameRW(fd, msgWriteTimeout)
+	if err := writeProtocolHandshake(rw, our); err != nil {
+		return nil, fmt.Errorf("protocol write error: %v", err)
 	}
-	return remoteID, sessionToken, err
+	rhs, err := readProtocolHandshake(rw, our)
+	if err != nil {
+		return nil, fmt.Errorf("protocol handshake read error: %v", err)
+	}
+	if rhs.ID != dial.ID {
+		return nil, errors.New("dialed node id mismatch")
+	}
+	return &conn{rw, rhs}, nil
 }
 
 // outboundEncHandshake negotiates a session token on conn.
@@ -66,18 +115,9 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
 	if err != nil {
 		return nil, err
 	}
-	if sessionToken != nil {
-		clogger.Debugf("session-token: %v", hexkey(sessionToken))
-	}
-
-	clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
-	clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
-	randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
-	clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
 	if _, err = conn.Write(auth); err != nil {
 		return nil, err
 	}
-	clogger.Debugf("initiator handshake: %v", hexkey(auth))
 
 	response := make([]byte, rHSLen)
 	if _, err = io.ReadFull(conn, response); err != nil {
@@ -88,9 +128,6 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
 		return nil, err
 	}
 
-	clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
-	remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
-	clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
 	return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
 }
 
@@ -221,12 +258,9 @@ func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionTo
 	if err != nil {
 		return nil, nil, err
 	}
-	clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
-	clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
 	if _, err = conn.Write(response); err != nil {
 		return nil, nil, err
 	}
-	clogger.Debugf("receiver handshake:\n%v", hexkey(response))
 	token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
 	return token, remotePubKey, err
 }
@@ -361,3 +395,40 @@ func xor(one, other []byte) (xor []byte) {
 	}
 	return xor
 }
+
+func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
+	return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
+}
+
+func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
+	// read and handle remote handshake
+	msg, err := r.ReadMsg()
+	if err != nil {
+		return nil, err
+	}
+	if msg.Code == discMsg {
+		// disconnect before protocol handshake is valid according to the
+		// spec and we send it ourself if Server.addPeer fails.
+		var reason DiscReason
+		rlp.Decode(msg.Payload, &reason)
+		return nil, discRequestedError(reason)
+	}
+	if msg.Code != handshakeMsg {
+		return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
+	}
+	if msg.Size > baseProtocolMaxMsgSize {
+		return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
+	}
+	var hs protoHandshake
+	if err := msg.Decode(&hs); err != nil {
+		return nil, err
+	}
+	// validate handshake info
+	if hs.Version != our.Version {
+		return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
+	}
+	if (hs.ID == discover.NodeID{}) {
+		return nil, newPeerError(errPubkeyInvalid, "missing")
+	}
+	return &hs, nil
+}

+ 60 - 3
p2p/crypto_test.go → p2p/handshake_test.go

@@ -5,10 +5,12 @@ import (
 	"crypto/ecdsa"
 	"crypto/rand"
 	"net"
+	"reflect"
 	"testing"
 
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/crypto/ecies"
+	"github.com/ethereum/go-ethereum/p2p/discover"
 )
 
 func TestPublicKeyEncoding(t *testing.T) {
@@ -91,14 +93,14 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
 	if err != nil {
 		t.Errorf("%v", err)
 	}
-	t.Logf("-> %v", hexkey(auth))
+	// t.Logf("-> %v", hexkey(auth))
 
 	// receiver reads auth and responds with response
 	response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
 	if err != nil {
 		t.Errorf("%v", err)
 	}
-	t.Logf("<- %v\n", hexkey(response))
+	// t.Logf("<- %v\n", hexkey(response))
 
 	// initiator reads receiver's response and the key exchange completes
 	recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
@@ -132,7 +134,7 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
 	}
 }
 
-func TestHandshake(t *testing.T) {
+func TestEncHandshake(t *testing.T) {
 	defer testlog(t).detach()
 
 	prv0, _ := crypto.GenerateKey()
@@ -165,3 +167,58 @@ func TestHandshake(t *testing.T) {
 		t.Error("session token mismatch")
 	}
 }
+
+func TestSetupConn(t *testing.T) {
+	prv0, _ := crypto.GenerateKey()
+	prv1, _ := crypto.GenerateKey()
+	node0 := &discover.Node{
+		ID:      discover.PubkeyID(&prv0.PublicKey),
+		IP:      net.IP{1, 2, 3, 4},
+		TCPPort: 33,
+	}
+	node1 := &discover.Node{
+		ID:      discover.PubkeyID(&prv1.PublicKey),
+		IP:      net.IP{5, 6, 7, 8},
+		TCPPort: 44,
+	}
+	hs0 := &protoHandshake{
+		Version: baseProtocolVersion,
+		ID:      node0.ID,
+		Caps:    []Cap{{"a", 0}, {"b", 2}},
+	}
+	hs1 := &protoHandshake{
+		Version: baseProtocolVersion,
+		ID:      node1.ID,
+		Caps:    []Cap{{"c", 1}, {"d", 3}},
+	}
+	fd0, fd1 := net.Pipe()
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		conn0, err := setupConn(fd0, prv0, hs0, node1)
+		if err != nil {
+			t.Errorf("outbound side error: %v", err)
+			return
+		}
+		if conn0.ID != node1.ID {
+			t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
+		}
+		if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
+			t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
+		}
+	}()
+
+	conn1, err := setupConn(fd1, prv1, hs1, nil)
+	if err != nil {
+		t.Fatalf("inbound side error: %v", err)
+	}
+	if conn1.ID != node0.ID {
+		t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
+	}
+	if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
+		t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
+	}
+
+	<-done
+}

+ 1 - 1
p2p/message.go

@@ -197,7 +197,7 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) {
 		return msg, err
 	}
 	if !bytes.HasPrefix(start, magicToken) {
-		return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
+		return msg, fmt.Errorf("bad magic token %x", start[:4])
 	}
 	size := binary.BigEndian.Uint32(start[4:])
 

+ 59 - 170
p2p/peer.go

@@ -33,37 +33,14 @@ const (
 	peersMsg     = 0x05
 )
 
-// handshake is the RLP structure of the protocol handshake.
-type handshake struct {
-	Version    uint64
-	Name       string
-	Caps       []Cap
-	ListenPort uint64
-	NodeID     discover.NodeID
-}
-
 // Peer represents a connected remote node.
 type Peer struct {
 	// Peers have all the log methods.
 	// Use them to display messages related to the peer.
 	*logger.Logger
 
-	infoMu sync.Mutex
-	name   string
-	caps   []Cap
-
-	ourID, remoteID *discover.NodeID
-	ourName         string
-
-	rw *frameRW
-
-	// These fields maintain the running protocols.
-	protocols []Protocol
-	runlock   sync.RWMutex // protects running
-	running   map[string]*proto
-
-	// disables protocol handshake, for testing
-	noHandshake bool
+	rw      *conn
+	running map[string]*protoRW
 
 	protoWG  sync.WaitGroup
 	protoErr chan error
@@ -73,36 +50,27 @@ type Peer struct {
 
 // NewPeer returns a peer for testing purposes.
 func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
-	conn, _ := net.Pipe()
-	peer := newPeer(conn, nil, "", nil, &id)
-	peer.setHandshakeInfo(name, caps)
+	pipe, _ := net.Pipe()
+	conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
+	peer := newPeer(conn, nil)
 	close(peer.closed) // ensures Disconnect doesn't block
 	return peer
 }
 
 // ID returns the node's public key.
 func (p *Peer) ID() discover.NodeID {
-	return *p.remoteID
+	return p.rw.ID
 }
 
 // Name returns the node name that the remote node advertised.
 func (p *Peer) Name() string {
-	// this needs a lock because the information is part of the
-	// protocol handshake.
-	p.infoMu.Lock()
-	name := p.name
-	p.infoMu.Unlock()
-	return name
+	return p.rw.Name
 }
 
 // Caps returns the capabilities (supported subprotocols) of the remote peer.
 func (p *Peer) Caps() []Cap {
-	// this needs a lock because the information is part of the
-	// protocol handshake.
-	p.infoMu.Lock()
-	caps := p.caps
-	p.infoMu.Unlock()
-	return caps
+	// TODO: maybe return copy
+	return p.rw.Caps
 }
 
 // RemoteAddr returns the remote address of the network connection.
@@ -126,30 +94,20 @@ func (p *Peer) Disconnect(reason DiscReason) {
 
 // String implements fmt.Stringer.
 func (p *Peer) String() string {
-	return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
+	return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
 }
 
-func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
-	logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
-	return &Peer{
-		Logger:    logger.NewLogger(logtag),
-		rw:        newFrameRW(conn, msgWriteTimeout),
-		ourID:     ourID,
-		ourName:   ourName,
-		remoteID:  remoteID,
-		protocols: protocols,
-		running:   make(map[string]*proto),
-		disc:      make(chan DiscReason),
-		protoErr:  make(chan error),
-		closed:    make(chan struct{}),
+func newPeer(conn *conn, protocols []Protocol) *Peer {
+	logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
+	p := &Peer{
+		Logger:   logger.NewLogger(logtag),
+		rw:       conn,
+		running:  matchProtocols(protocols, conn.Caps, conn),
+		disc:     make(chan DiscReason),
+		protoErr: make(chan error),
+		closed:   make(chan struct{}),
 	}
-}
-
-func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
-	p.infoMu.Lock()
-	p.name = name
-	p.caps = caps
-	p.infoMu.Unlock()
+	return p
 }
 
 func (p *Peer) run() DiscReason {
@@ -157,16 +115,9 @@ func (p *Peer) run() DiscReason {
 	defer p.closeProtocols()
 	defer close(p.closed)
 
+	p.startProtocols()
 	go func() { readErr <- p.readLoop() }()
 
-	if !p.noHandshake {
-		if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
-			p.DebugDetailf("Protocol handshake error: %v\n", err)
-			p.rw.Close()
-			return DiscProtocolError
-		}
-	}
-
 	// Wait for an error or disconnect.
 	var reason DiscReason
 	select {
@@ -206,11 +157,6 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
 }
 
 func (p *Peer) readLoop() error {
-	if !p.noHandshake {
-		if err := readProtocolHandshake(p, p.rw); err != nil {
-			return err
-		}
-	}
 	for {
 		msg, err := p.rw.ReadMsg()
 		if err != nil {
@@ -249,105 +195,51 @@ func (p *Peer) handle(msg Msg) error {
 	return nil
 }
 
-func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
-	// read and handle remote handshake
-	msg, err := rw.ReadMsg()
-	if err != nil {
-		return err
-	}
-	if msg.Code == discMsg {
-		// disconnect before protocol handshake is valid according to the
-		// spec and we send it ourself if Server.addPeer fails.
-		var reason DiscReason
-		rlp.Decode(msg.Payload, &reason)
-		return discRequestedError(reason)
-	}
-	if msg.Code != handshakeMsg {
-		return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
-	}
-	if msg.Size > baseProtocolMaxMsgSize {
-		return newPeerError(errInvalidMsg, "message too big")
-	}
-	var hs handshake
-	if err := msg.Decode(&hs); err != nil {
-		return err
-	}
-	// validate handshake info
-	if hs.Version != baseProtocolVersion {
-		return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
-			baseProtocolVersion, hs.Version)
-	}
-	if hs.NodeID == *p.remoteID {
-		return newPeerError(errPubkeyForbidden, "node ID mismatch")
-	}
-	// TODO: remove Caps with empty name
-	p.setHandshakeInfo(hs.Name, hs.Caps)
-	p.startSubprotocols(hs.Caps)
-	return nil
-}
-
-func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
-	var caps []interface{}
-	for _, proto := range ps {
-		caps = append(caps, proto.cap())
-	}
-	return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
-}
-
-// startProtocols starts matching named subprotocols.
-func (p *Peer) startSubprotocols(caps []Cap) {
+// matchProtocols creates structures for matching named subprotocols.
+func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
 	sort.Sort(capsByName(caps))
-	p.runlock.Lock()
-	defer p.runlock.Unlock()
 	offset := baseProtocolLength
+	result := make(map[string]*protoRW)
 outer:
 	for _, cap := range caps {
-		for _, proto := range p.protocols {
-			if proto.Name == cap.Name &&
-				proto.Version == cap.Version &&
-				p.running[cap.Name] == nil {
-				p.running[cap.Name] = p.startProto(offset, proto)
+		for _, proto := range protocols {
+			if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
+				result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
 				offset += proto.Length
 				continue outer
 			}
 		}
 	}
+	return result
 }
 
-func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
-	p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
-	rw := &proto{
-		name:    impl.Name,
-		in:      make(chan Msg),
-		offset:  offset,
-		maxcode: impl.Length,
-		w:       p.rw,
+func (p *Peer) startProtocols() {
+	for _, proto := range p.running {
+		proto := proto
+		p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
+		p.protoWG.Add(1)
+		go func() {
+			err := proto.Run(p, proto)
+			if err == nil {
+				p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
+				err = errors.New("protocol returned")
+			} else {
+				p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
+			}
+			select {
+			case p.protoErr <- err:
+			case <-p.closed:
+			}
+			p.protoWG.Done()
+		}()
 	}
-	p.protoWG.Add(1)
-	go func() {
-		err := impl.Run(p, rw)
-		if err == nil {
-			p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
-			err = errors.New("protocol returned")
-		} else {
-			p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
-		}
-		select {
-		case p.protoErr <- err:
-		case <-p.closed:
-		}
-		p.protoWG.Done()
-	}()
-	return rw
 }
 
 // getProto finds the protocol responsible for handling
 // the given message code.
-func (p *Peer) getProto(code uint64) (*proto, error) {
-	p.runlock.RLock()
-	defer p.runlock.RUnlock()
+func (p *Peer) getProto(code uint64) (*protoRW, error) {
 	for _, proto := range p.running {
-		if code >= proto.offset && code < proto.offset+proto.maxcode {
+		if code >= proto.offset && code < proto.offset+proto.Length {
 			return proto, nil
 		}
 	}
@@ -355,46 +247,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) {
 }
 
 func (p *Peer) closeProtocols() {
-	p.runlock.RLock()
 	for _, p := range p.running {
 		close(p.in)
 	}
-	p.runlock.RUnlock()
 	p.protoWG.Wait()
 }
 
 // writeProtoMsg sends the given message on behalf of the given named protocol.
 // this exists because of Server.Broadcast.
 func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
-	p.runlock.RLock()
 	proto, ok := p.running[protoName]
-	p.runlock.RUnlock()
 	if !ok {
 		return fmt.Errorf("protocol %s not handled by peer", protoName)
 	}
-	if msg.Code >= proto.maxcode {
+	if msg.Code >= proto.Length {
 		return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
 	}
 	msg.Code += proto.offset
 	return p.rw.WriteMsg(msg)
 }
 
-type proto struct {
-	name            string
-	in              chan Msg
-	maxcode, offset uint64
-	w               MsgWriter
+type protoRW struct {
+	Protocol
+
+	in     chan Msg
+	offset uint64
+	w      MsgWriter
 }
 
-func (rw *proto) WriteMsg(msg Msg) error {
-	if msg.Code >= rw.maxcode {
+func (rw *protoRW) WriteMsg(msg Msg) error {
+	if msg.Code >= rw.Length {
 		return newPeerError(errInvalidMsgCode, "not handled")
 	}
 	msg.Code += rw.offset
 	return rw.w.WriteMsg(msg)
 }
 
-func (rw *proto) ReadMsg() (Msg, error) {
+func (rw *protoRW) ReadMsg() (Msg, error) {
 	msg, ok := <-rw.in
 	if !ok {
 		return msg, io.EOF

+ 20 - 85
p2p/peer_test.go

@@ -6,11 +6,9 @@ import (
 	"io/ioutil"
 	"net"
 	"reflect"
-	"sort"
 	"testing"
 	"time"
 
-	"github.com/ethereum/go-ethereum/p2p/discover"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 
@@ -23,6 +21,7 @@ var discard = Protocol{
 			if err != nil {
 				return err
 			}
+			fmt.Printf("discarding %d\n", msg.Code)
 			if err = msg.Discard(); err != nil {
 				return err
 			}
@@ -30,13 +29,20 @@ var discard = Protocol{
 	},
 }
 
-func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
-	conn1, conn2 := net.Pipe()
-	peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
-	peer.noHandshake = noHandshake
+func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
+	fd1, fd2 := net.Pipe()
+	hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
+	hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
+	for _, p := range protos {
+		hs1.Caps = append(hs1.Caps, p.cap())
+		hs2.Caps = append(hs2.Caps, p.cap())
+	}
+
+	peer := newPeer(newConn(fd1, hs1), protos)
 	errc := make(chan DiscReason, 1)
 	go func() { errc <- peer.run() }()
-	return newFrameRW(conn2, msgWriteTimeout), peer, errc
+
+	return newConn(fd2, hs2), peer, errc
 }
 
 func TestPeerProtoReadMsg(t *testing.T) {
@@ -61,9 +67,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
 		},
 	}
 
-	rw, peer, errc := testPeer(true, []Protocol{proto})
+	rw, _, errc := testPeer([]Protocol{proto})
 	defer rw.Close()
-	peer.startSubprotocols([]Cap{proto.cap()})
 
 	EncodeMsg(rw, baseProtocolLength+2, 1)
 	EncodeMsg(rw, baseProtocolLength+3, 2)
@@ -100,9 +105,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
 		},
 	}
 
-	rw, peer, errc := testPeer(true, []Protocol{proto})
+	rw, _, errc := testPeer([]Protocol{proto})
 	defer rw.Close()
-	peer.startSubprotocols([]Cap{proto.cap()})
 
 	EncodeMsg(rw, 18, make([]byte, msgsize))
 	select {
@@ -130,9 +134,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
 			return nil
 		},
 	}
-	rw, peer, _ := testPeer(true, []Protocol{proto})
+	rw, _, _ := testPeer([]Protocol{proto})
 	defer rw.Close()
-	peer.startSubprotocols([]Cap{proto.cap()})
 
 	if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
 		t.Error(err)
@@ -142,9 +145,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
 func TestPeerWriteForBroadcast(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, peer, peerErr := testPeer(true, []Protocol{discard})
+	rw, peer, peerErr := testPeer([]Protocol{discard})
 	defer rw.Close()
-	peer.startSubprotocols([]Cap{discard.cap()})
 
 	// test write errors
 	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
@@ -160,7 +162,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
 	read := make(chan struct{})
 	go func() {
 		if err := expectMsg(rw, 16, nil); err != nil {
-			t.Error()
+			t.Error(err)
 		}
 		close(read)
 	}()
@@ -179,7 +181,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
 func TestPeerPing(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, _, _ := testPeer(true, nil)
+	rw, _, _ := testPeer(nil)
 	defer rw.Close()
 	if err := EncodeMsg(rw, pingMsg); err != nil {
 		t.Fatal(err)
@@ -192,7 +194,7 @@ func TestPeerPing(t *testing.T) {
 func TestPeerDisconnect(t *testing.T) {
 	defer testlog(t).detach()
 
-	rw, _, disc := testPeer(true, nil)
+	rw, _, disc := testPeer(nil)
 	defer rw.Close()
 	if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
 		t.Fatal(err)
@@ -206,73 +208,6 @@ func TestPeerDisconnect(t *testing.T) {
 	}
 }
 
-func TestPeerHandshake(t *testing.T) {
-	defer testlog(t).detach()
-
-	// remote has two matching protocols: a and c
-	remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
-	remoteID := randomID()
-	remote.ourID = &remoteID
-	remote.ourName = "remote peer"
-
-	start := make(chan string)
-	stop := make(chan struct{})
-	run := func(p *Peer, rw MsgReadWriter) error {
-		name := rw.(*proto).name
-		if name != "a" && name != "c" {
-			t.Errorf("protocol %q should not be started", name)
-		} else {
-			start <- name
-		}
-		<-stop
-		return nil
-	}
-	protocols := []Protocol{
-		{Name: "a", Version: 1, Length: 1, Run: run},
-		{Name: "b", Version: 2, Length: 1, Run: run},
-		{Name: "c", Version: 3, Length: 1, Run: run},
-		{Name: "d", Version: 4, Length: 1, Run: run},
-	}
-	rw, p, disc := testPeer(false, protocols)
-	p.remoteID = remote.ourID
-	defer rw.Close()
-
-	// run the handshake
-	remoteProtocols := []Protocol{protocols[0], protocols[2]}
-	if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
-		t.Fatalf("handshake write error: %v", err)
-	}
-	if err := readProtocolHandshake(remote, rw); err != nil {
-		t.Fatalf("handshake read error: %v", err)
-	}
-
-	// check that all protocols have been started
-	var started []string
-	for i := 0; i < 2; i++ {
-		select {
-		case name := <-start:
-			started = append(started, name)
-		case <-time.After(100 * time.Millisecond):
-		}
-	}
-	sort.Strings(started)
-	if !reflect.DeepEqual(started, []string{"a", "c"}) {
-		t.Errorf("wrong protocols started: %v", started)
-	}
-
-	// check that metadata has been set
-	if p.ID() != remoteID {
-		t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
-	}
-	if p.Name() != remote.ourName {
-		t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
-	}
-
-	close(stop)
-	expectMsg(rw, discMsg, nil)
-	t.Logf("disc reason: %v", <-disc)
-}
-
 func TestNewPeer(t *testing.T) {
 	name := "nodename"
 	caps := []Cap{{"foo", 2}, {"bar", 3}}

+ 22 - 16
p2p/server.go

@@ -5,7 +5,6 @@ import (
 	"crypto/ecdsa"
 	"errors"
 	"fmt"
-	"io"
 	"net"
 	"runtime"
 	"sync"
@@ -83,9 +82,11 @@ type Server struct {
 
 	// Hooks for testing. These are useful because we can inhibit
 	// the whole protocol stack.
-	handshakeFunc
+	setupFunc
 	newPeerHook
 
+	ourHandshake *protoHandshake
+
 	lock     sync.RWMutex
 	running  bool
 	listener net.Listener
@@ -99,7 +100,7 @@ type Server struct {
 	peerConnect chan *discover.Node
 }
 
-type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
+type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
 type newPeerHook func(*Peer)
 
 // Peers returns all connected peers.
@@ -170,8 +171,8 @@ func (srv *Server) Start() (err error) {
 	srv.peers = make(map[discover.NodeID]*Peer)
 	srv.peerConnect = make(chan *discover.Node)
 
-	if srv.handshakeFunc == nil {
-		srv.handshakeFunc = encHandshake
+	if srv.setupFunc == nil {
+		srv.setupFunc = setupConn
 	}
 	if srv.Blacklist == nil {
 		srv.Blacklist = NewBlacklist()
@@ -183,11 +184,17 @@ func (srv *Server) Start() (err error) {
 	}
 
 	// dial stuff
-	dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
+	ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
 	if err != nil {
 		return err
 	}
-	srv.ntab = dt
+	srv.ntab = ntab
+
+	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self()}
+	for _, p := range srv.Protocols {
+		srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
+	}
+
 	if srv.Dialer == nil {
 		srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
 	}
@@ -347,18 +354,17 @@ func (srv *Server) findPeers() {
 	}
 }
 
-func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
+func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
 	// TODO: handle/store session token
-	conn.SetDeadline(time.Now().Add(handshakeTimeout))
-	remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
+	fd.SetDeadline(time.Now().Add(handshakeTimeout))
+	conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
 	if err != nil {
-		conn.Close()
-		srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
+		fd.Close()
+		srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
 		return
 	}
-	ourID := srv.ntab.Self()
-	p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
-	if ok, reason := srv.addPeer(remoteID, p); !ok {
+	p := newPeer(conn, srv.Protocols)
+	if ok, reason := srv.addPeer(conn.ID, p); !ok {
 		srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
 		p.politeDisconnect(reason)
 		return
@@ -394,7 +400,7 @@ func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
 
 func (srv *Server) removePeer(p *Peer) {
 	srv.lock.Lock()
-	delete(srv.peers, *p.remoteID)
+	delete(srv.peers, p.ID())
 	srv.lock.Unlock()
 	srv.peerWG.Done()
 }

+ 7 - 5
p2p/server_test.go

@@ -21,8 +21,12 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
 		ListenAddr:  "127.0.0.1:0",
 		PrivateKey:  newkey(),
 		newPeerHook: pf,
-		handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
-			return randomID(), nil, err
+		setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
+			id := randomID()
+			return &conn{
+				frameRW:        newFrameRW(fd, msgWriteTimeout),
+				protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
+			}, nil
 		},
 	}
 	if err := server.Start(); err != nil {
@@ -116,9 +120,7 @@ func TestServerBroadcast(t *testing.T) {
 
 	var connected sync.WaitGroup
 	srv := startTestServer(t, func(p *Peer) {
-		p.protocols = []Protocol{discard}
-		p.startSubprotocols([]Cap{discard.cap()})
-		p.noHandshake = true
+		p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
 		connected.Done()
 	})
 	defer srv.Stop()