Browse Source

p2p: don't send DiscReason when using net.Pipe (#16004)

Anton Evangelatov 7 năm trước cách đây
mục cha
commit
1e457b6599
2 tập tin đã thay đổi với 43 bổ sung5 xóa
  1. 8 2
      p2p/rlpx.go
  2. 35 3
      p2p/rlpx_test.go

+ 8 - 2
p2p/rlpx.go

@@ -108,8 +108,14 @@ func (t *rlpx) close(err error) {
 	// Tell the remote end why we're disconnecting if possible.
 	if t.rw != nil {
 		if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
-			t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout))
-			SendItems(t.rw, discMsg, r)
+			// rlpx tries to send DiscReason to disconnected peer
+			// if the connection is net.Pipe (in-memory simulation)
+			// it hangs forever, since net.Pipe does not implement
+			// a write deadline. Because of this only try to send
+			// the disconnect reason message if there is no error.
+			if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil {
+				SendItems(t.rw, discMsg, r)
+			}
 		}
 	}
 	t.fd.Close()

+ 35 - 3
p2p/rlpx_test.go

@@ -156,14 +156,18 @@ func TestProtocolHandshake(t *testing.T) {
 		node1   = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
 		hs1     = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
 
-		fd0, fd1 = net.Pipe()
-		wg       sync.WaitGroup
+		wg sync.WaitGroup
 	)
 
+	fd0, fd1, err := tcpPipe()
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	wg.Add(2)
 	go func() {
 		defer wg.Done()
-		defer fd1.Close()
+		defer fd0.Close()
 		rlpx := newRLPX(fd0)
 		remid, err := rlpx.doEncHandshake(prv0, node1)
 		if err != nil {
@@ -597,3 +601,31 @@ func TestHandshakeForwardCompatibility(t *testing.T) {
 		t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash)
 	}
 }
+
+// tcpPipe creates an in process full duplex pipe based on a localhost TCP socket
+func tcpPipe() (net.Conn, net.Conn, error) {
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return nil, nil, err
+	}
+	defer l.Close()
+
+	var aconn net.Conn
+	aerr := make(chan error, 1)
+	go func() {
+		var err error
+		aconn, err = l.Accept()
+		aerr <- err
+	}()
+
+	dconn, err := net.Dial("tcp", l.Addr().String())
+	if err != nil {
+		<-aerr
+		return nil, nil, err
+	}
+	if err := <-aerr; err != nil {
+		dconn.Close()
+		return nil, nil, err
+	}
+	return aconn, dconn, nil
+}