Эх сурвалжийг харах

p2p: add trust check to handshake, test privileged connectivity

Conflicts:
	p2p/server_test.go
Péter Szilágyi 10 жил өмнө
parent
commit
1528dbc171

+ 7 - 7
p2p/handshake.go

@@ -70,21 +70,21 @@ type protoHandshake struct {
 // If dial is non-nil, the connection the local node is the initiator.
 // If atcap is true, the connection will be disconnected with DiscTooManyPeers
 // after the key exchange.
-func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
+func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool, trust map[discover.NodeID]bool) (*conn, error) {
 	if dial == nil {
-		return setupInboundConn(fd, prv, our, atcap)
+		return setupInboundConn(fd, prv, our, atcap, trust)
 	} else {
-		return setupOutboundConn(fd, prv, our, dial, atcap)
+		return setupOutboundConn(fd, prv, our, dial, atcap, trust)
 	}
 }
 
-func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, atcap bool) (*conn, error) {
+func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, atcap bool, trust map[discover.NodeID]bool) (*conn, error) {
 	secrets, err := receiverEncHandshake(fd, prv, nil)
 	if err != nil {
 		return nil, fmt.Errorf("encryption handshake failed: %v", err)
 	}
 	rw := newRlpxFrameRW(fd, secrets)
-	if atcap {
+	if atcap && !trust[secrets.RemoteID] {
 		SendItems(rw, discMsg, DiscTooManyPeers)
 		return nil, errors.New("we have too many peers")
 	}
@@ -99,13 +99,13 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, a
 	return &conn{rw, rhs}, nil
 }
 
-func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
+func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool, trust map[discover.NodeID]bool) (*conn, error) {
 	secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
 	if err != nil {
 		return nil, fmt.Errorf("encryption handshake failed: %v", err)
 	}
 	rw := newRlpxFrameRW(fd, secrets)
-	if atcap {
+	if atcap && !trust[secrets.RemoteID] {
 		SendItems(rw, discMsg, DiscTooManyPeers)
 		return nil, errors.New("we have too many peers")
 	}

+ 2 - 2
p2p/handshake_test.go

@@ -143,7 +143,7 @@ func TestSetupConn(t *testing.T) {
 	done := make(chan struct{})
 	go func() {
 		defer close(done)
-		conn0, err := setupConn(fd0, prv0, hs0, node1, false)
+		conn0, err := setupConn(fd0, prv0, hs0, node1, false, nil)
 		if err != nil {
 			t.Errorf("outbound side error: %v", err)
 			return
@@ -156,7 +156,7 @@ func TestSetupConn(t *testing.T) {
 		}
 	}()
 
-	conn1, err := setupConn(fd1, prv1, hs1, nil, false)
+	conn1, err := setupConn(fd1, prv1, hs1, nil, false, nil)
 	if err != nil {
 		t.Fatalf("inbound side error: %v", err)
 	}

+ 14 - 3
p2p/server.go

@@ -115,7 +115,7 @@ type Server struct {
 	peerWG sync.WaitGroup // active peer goroutines
 }
 
-type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, bool) (*conn, error)
+type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, bool, map[discover.NodeID]bool) (*conn, error)
 type newPeerHook func(*Peer)
 
 // Peers returns all connected peers.
@@ -140,7 +140,10 @@ func (srv *Server) PeerCount() int {
 
 // TrustPeer inserts a node into the list of privileged nodes.
 func (srv *Server) TrustPeer(node *discover.Node) {
-	srv.trustDial <- node
+	srv.lock.Lock()
+	defer srv.lock.Unlock()
+
+	srv.trusts[node.ID] = node
 }
 
 // Broadcast sends an RLP-encoded message to all connected peers.
@@ -470,10 +473,18 @@ func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
 	// returns during that exchange need to call peerWG.Done because
 	// the callers of startPeer added the peer to the wait group already.
 	fd.SetDeadline(time.Now().Add(handshakeTimeout))
+
+	// Check capacity and trust list
 	srv.lock.RLock()
 	atcap := len(srv.peers) == srv.MaxPeers
+
+	trust := make(map[discover.NodeID]bool)
+	for id, _ := range srv.trusts {
+		trust[id] = true
+	}
 	srv.lock.RUnlock()
-	conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, atcap)
+
+	conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, atcap, trust)
 	if err != nil {
 		fd.Close()
 		glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)

+ 65 - 3
p2p/server_test.go

@@ -22,7 +22,7 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
 		ListenAddr:  "127.0.0.1:0",
 		PrivateKey:  newkey(),
 		newPeerHook: pf,
-		setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool) (*conn, error) {
+		setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, atcap bool, trust map[discover.NodeID]bool) (*conn, error) {
 			id := randomID()
 			rw := newRlpxFrameRW(fd, secrets{
 				MAC:        zero16,
@@ -102,7 +102,7 @@ func TestServerDial(t *testing.T) {
 
 	// tell the server to connect
 	tcpAddr := listener.Addr().(*net.TCPAddr)
-	srv.trustDial <-&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port}
+	srv.trustDial <- &discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port}
 
 	select {
 	case conn := <-accepted:
@@ -200,7 +200,7 @@ func TestServerDisconnectAtCap(t *testing.T) {
 		// Run the handshakes just like a real peer would.
 		key := newkey()
 		hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
-		_, err = setupConn(conn, key, hs, srv.Self(), false)
+		_, err = setupConn(conn, key, hs, srv.Self(), false, nil)
 		if i == nconns-1 {
 			// When handling the last connection, the server should
 			// disconnect immediately instead of running the protocol
@@ -219,6 +219,68 @@ func TestServerDisconnectAtCap(t *testing.T) {
 	}
 }
 
+// Tests that trusted peers and can connect above max peer caps.
+func TestServerTrustedPeers(t *testing.T) {
+	defer testlog(t).detach()
+
+	// Create a test server with limited connection slots
+	started := make(chan *Peer)
+	server := &Server{
+		ListenAddr:  "127.0.0.1:0",
+		PrivateKey:  newkey(),
+		MaxPeers:    3,
+		NoDial:      true,
+		newPeerHook: func(p *Peer) { started <- p },
+	}
+	if err := server.Start(); err != nil {
+		t.Fatal(err)
+	}
+	defer server.Stop()
+
+	// Fill up all the slots on the server
+	dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
+	for i := 0; i < server.MaxPeers; i++ {
+		// Establish a new connection
+		conn, err := dialer.Dial("tcp", server.ListenAddr)
+		if err != nil {
+			t.Fatalf("conn %d: dial error: %v", i, err)
+		}
+		defer conn.Close()
+
+		// Run the handshakes just like a real peer would, and wait for completion
+		key := newkey()
+		shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
+		if _, err = setupConn(conn, key, shake, server.Self(), false, nil); err != nil {
+			t.Fatalf("conn %d: unexpected error: %v", i, err)
+		}
+		<-started
+	}
+	// Inject a trusted node and dial that (we'll connect from this end, don't need IP setup)
+	key := newkey()
+	trusted := &discover.Node{
+		ID: discover.PubkeyID(&key.PublicKey),
+	}
+	server.TrustPeer(trusted)
+
+	conn, err := dialer.Dial("tcp", server.ListenAddr)
+	if err != nil {
+		t.Fatalf("trusted node: dial error: %v", err)
+	}
+	defer conn.Close()
+
+	shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID}
+	if _, err = setupConn(conn, key, shake, server.Self(), false, nil); err != nil {
+		t.Fatalf("trusted node: unexpected error: %v", err)
+	}
+	select {
+	case <-started:
+		// Ok, trusted peer accepted
+
+	case <-time.After(100 * time.Millisecond):
+		t.Fatalf("trusted node timeout")
+	}
+}
+
 func newkey() *ecdsa.PrivateKey {
 	key, err := crypto.GenerateKey()
 	if err != nil {