Browse Source

[R4R] add timeout for stopping p2p server (#643)

* add timeout for stopping p2p server

* extend extension wait time

* add unit tests

* fix lint issue
yutianwu 3 years ago
parent
commit
7c1c8e2e88
6 changed files with 151 additions and 6 deletions
  1. 70 0
      eth/handler_eth_test.go
  2. 32 4
      eth/peerset.go
  3. 1 1
      eth/protocols/diff/handshake.go
  4. 10 0
      p2p/peer.go
  5. 15 1
      p2p/server.go
  6. 23 0
      p2p/server_test.go

+ 70 - 0
eth/handler_eth_test.go

@@ -239,6 +239,76 @@ func testForkIDSplit(t *testing.T, protocol uint) {
 func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, eth.ETH65) }
 func TestRecvTransactions66(t *testing.T) { testRecvTransactions(t, eth.ETH66) }
 
+func TestWaitDiffExtensionTimout(t *testing.T) {
+	t.Parallel()
+
+	// Create a test handler; it is closed when the test returns
+	handler := newTestHandler()
+	defer handler.close()
+
+	// Create a message pipe and keep only one end for the peer; the other end is discarded
+	_, p2pSink := p2p.MsgPipe()
+	defer p2pSink.Close()
+
+	protos := []p2p.Protocol{
+		{
+			Name:    "diff",
+			Version: 1,
+		},
+	}
+
+	sink := eth.NewPeer(eth.ETH67, p2p.NewPeerWithProtocols(enode.ID{2}, protos, "", []p2p.Cap{
+		{
+			Name:    "diff",
+			Version: 1,
+		},
+	}), p2pSink, nil)
+	defer sink.Close()
+
+	err := handler.handler.runEthPeer(sink, func(peer *eth.Peer) error {
+		return eth.Handle((*ethHandler)(handler.handler), peer)
+	}) // NOTE(review): presumably returns only after the diff extension wait times out — confirm
+
+	if err == nil || err.Error() != "peer wait timeout" { // must match errPeerWaitTimeout's message
+		t.Fatalf("error should be `peer wait timeout`")
+	}
+}
+
+func TestWaitSnapExtensionTimout(t *testing.T) {
+	t.Parallel()
+
+	// Create a test handler; it is closed when the test returns
+	handler := newTestHandler()
+	defer handler.close()
+
+	// Create a message pipe and keep only one end for the peer; the other end is discarded
+	_, p2pSink := p2p.MsgPipe()
+	defer p2pSink.Close()
+
+	protos := []p2p.Protocol{
+		{
+			Name:    "snap",
+			Version: 1,
+		},
+	}
+
+	sink := eth.NewPeer(eth.ETH67, p2p.NewPeerWithProtocols(enode.ID{2}, protos, "", []p2p.Cap{
+		{
+			Name:    "snap",
+			Version: 1,
+		},
+	}), p2pSink, nil)
+	defer sink.Close()
+
+	err := handler.handler.runEthPeer(sink, func(peer *eth.Peer) error {
+		return eth.Handle((*ethHandler)(handler.handler), peer)
+	}) // NOTE(review): presumably returns only after the snap extension wait times out — confirm
+
+	if err == nil || err.Error() != "peer wait timeout" { // must match errPeerWaitTimeout's message
+		t.Fatalf("error should be `peer wait timeout`")
+	}
+}
+
 func testRecvTransactions(t *testing.T, protocol uint) {
 	t.Parallel()
 

+ 32 - 4
eth/peerset.go

@@ -20,6 +20,7 @@ import (
 	"errors"
 	"math/big"
 	"sync"
+	"time"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/eth/downloader"
@@ -38,19 +39,28 @@ var (
 	// to the peer set, but one with the same id already exists.
 	errPeerAlreadyRegistered = errors.New("peer already registered")
 
+	// errPeerWaitTimeout is returned if waiting for a peer's extension protocol takes too long
+	errPeerWaitTimeout = errors.New("peer wait timeout")
+
 	// errPeerNotRegistered is returned if a peer is attempted to be removed from
 	// a peer set, but no peer with the given id exists.
 	errPeerNotRegistered = errors.New("peer not registered")
 
 	// errSnapWithoutEth is returned if a peer attempts to connect only on the
-	// snap protocol without advertizing the eth main protocol.
+	// snap protocol without advertising the eth main protocol.
 	errSnapWithoutEth = errors.New("peer connected on snap without compatible eth support")
 
 	// errDiffWithoutEth is returned if a peer attempts to connect only on the
-	// diff protocol without advertizing the eth main protocol.
+	// diff protocol without advertising the eth main protocol.
 	errDiffWithoutEth = errors.New("peer connected on diff without compatible eth support")
 )
 
+const (
+	// extensionWaitTimeout is the maximum allowed time for the extension wait to
+	// complete before dropping the connection as malicious.
+	extensionWaitTimeout = 10 * time.Second
+)
+
 // peerSet represents the collection of active peers currently participating in
 // the `eth` protocol, with or without the `snap` extension.
 type peerSet struct {
@@ -169,7 +179,16 @@ func (ps *peerSet) waitSnapExtension(peer *eth.Peer) (*snap.Peer, error) {
 	ps.snapWait[id] = wait
 	ps.lock.Unlock()
 
-	return <-wait, nil
+	select {
+	case peer := <-wait:
+		return peer, nil
+
+	case <-time.After(extensionWaitTimeout):
+		ps.lock.Lock()
+		delete(ps.snapWait, id)
+		ps.lock.Unlock()
+		return nil, errPeerWaitTimeout
+	}
 }
 
 // waitDiffExtension blocks until all satellite protocols are connected and tracked
@@ -203,7 +222,16 @@ func (ps *peerSet) waitDiffExtension(peer *eth.Peer) (*diff.Peer, error) {
 	ps.diffWait[id] = wait
 	ps.lock.Unlock()
 
-	return <-wait, nil
+	select {
+	case peer := <-wait:
+		return peer, nil
+
+	case <-time.After(extensionWaitTimeout):
+		ps.lock.Lock()
+		delete(ps.diffWait, id)
+		ps.lock.Unlock()
+		return nil, errPeerWaitTimeout
+	}
 }
 
 func (ps *peerSet) GetDiffPeer(pid string) downloader.IDiffPeer {

+ 1 - 1
eth/protocols/diff/handshake.go

@@ -26,7 +26,7 @@ import (
 
 const (
 	// handshakeTimeout is the maximum allowed time for the `diff` handshake to
-	// complete before dropping the connection.= as malicious.
+	// complete before dropping the connection as malicious.
 	handshakeTimeout = 5 * time.Second
 )
 

+ 10 - 0
p2p/peer.go

@@ -129,6 +129,16 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
 	return peer
 }
 
+// NewPeerWithProtocols returns a peer built from the given protocols and capabilities, for testing purposes.
+func NewPeerWithProtocols(id enode.ID, protocols []Protocol, name string, caps []Cap) *Peer {
+	pipe, _ := net.Pipe() // only one end of the pipe is kept, as the connection's fd
+	node := enode.SignNull(new(enr.Record), id)
+	conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
+	peer := newPeer(log.Root(), conn, protocols)
+	close(peer.closed) // ensures Disconnect doesn't block
+	return peer
+}
+
 // ID returns the node's public key.
 func (p *Peer) ID() enode.ID {
 	return p.rw.node.ID()

+ 15 - 1
p2p/server.go

@@ -63,6 +63,9 @@ const (
 
 	// Maximum amount of time allowed for writing a complete message.
 	frameWriteTimeout = 20 * time.Second
+
+	// Maximum time Stop waits on loopWG before giving up and returning.
+	stopTimeout = 5 * time.Second
 )
 
 var errServerStopped = errors.New("server stopped")
@@ -403,7 +406,18 @@ func (srv *Server) Stop() {
 	}
 	close(srv.quit)
 	srv.lock.Unlock()
-	srv.loopWG.Wait()
+
+	stopChan := make(chan struct{})
+	go func() {
+		srv.loopWG.Wait()
+		close(stopChan)
+	}()
+
+	select {
+	case <-stopChan:
+	case <-time.After(stopTimeout):
+		srv.log.Warn("stop p2p server timeout, forcing stop")
+	}
 }
 
 // sharedUDPConn implements a shared connection. Write sends messages to the underlying connection while read returns

+ 23 - 0
p2p/server_test.go

@@ -203,6 +203,29 @@ func TestServerDial(t *testing.T) {
 	}
 }
 
+func TestServerStopTimeout(t *testing.T) {
+	srv := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "1"),
+	}}
+	srv.Start()
+	srv.loopWG.Add(1) // no matching Done is called in this test, so waiting on loopWG alone would not finish
+
+	stopChan := make(chan struct{}) // closed once Stop has returned
+	go func() {
+		srv.Stop()
+		close(stopChan)
+	}()
+
+	select {
+	case <-stopChan:
+	case <-time.After(10 * time.Second): // fail if Stop has not returned within 10 seconds
+		t.Error("server should be shutdown in 10 seconds")
+	}
+}
+
 // This test checks that RemovePeer disconnects the peer if it is connected.
 func TestServerRemovePeerDisconnect(t *testing.T) {
 	srv1 := &Server{Config: Config{