Browse Source

cmd/devp2p/internal/ethtest: add more tx propagation tests (#22630)

This adds a test for large tx announcement messages, as well as a test to
check that announced tx hashes are requested by the node.
rene 4 years ago
parent
commit
cac1b21d39

+ 81 - 16
cmd/devp2p/internal/ethtest/eth66_suite.go

@@ -19,6 +19,7 @@ package ethtest
 import (
 	"time"
 
+	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/eth/protocols/eth"
@@ -125,22 +126,7 @@ func (s *Suite) TestSimultaneousRequests_66(t *utesting.T) {
 // TestBroadcast_66 tests whether a block announcement is correctly
 // propagated to the given node's peer(s) on the eth66 protocol.
 func (s *Suite) TestBroadcast_66(t *utesting.T) {
-	sendConn, receiveConn := s.setupConnection66(t), s.setupConnection66(t)
-	defer sendConn.Close()
-	defer receiveConn.Close()
-
-	nextBlock := len(s.chain.blocks)
-	blockAnnouncement := &NewBlock{
-		Block: s.fullChain.blocks[nextBlock],
-		TD:    s.fullChain.TD(nextBlock + 1),
-	}
-	s.testAnnounce66(t, sendConn, receiveConn, blockAnnouncement)
-	// update test suite chain
-	s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock])
-	// wait for client to update its chain
-	if err := receiveConn.waitForBlock66(s.chain.Head()); err != nil {
-		t.Fatal(err)
-	}
+	s.sendNextBlock66(t)
 }
 
 // TestGetBlockBodies_66 tests whether the given node can respond to
@@ -426,3 +412,82 @@ func (s *Suite) TestSameRequestID_66(t *utesting.T) {
 	// check response from first request
 	headersMatch(t, s.chain, s.getBlockHeaders66(t, conn, req1, reqID))
 }
+
+// TestLargeTxRequest_66 tests whether a node can fulfill a large GetPooledTransactions
+// request.
+func (s *Suite) TestLargeTxRequest_66(t *utesting.T) {
+	// send the next block to ensure the node is no longer syncing and is able to accept
+	// txs
+	s.sendNextBlock66(t)
+	// send 2000 transactions to the node
+	hashMap, txs := generateTxs(t, s, 2000)
+	sendConn := s.setupConnection66(t)
+	defer sendConn.Close()
+
+	sendMultipleSuccessfulTxs(t, s, sendConn, txs)
+	// set up connection to receive to ensure node is peered with the receiving connection
+	// before tx request is sent
+	recvConn := s.setupConnection66(t)
+	defer recvConn.Close()
+	// create and send pooled tx request
+	hashes := make([]common.Hash, 0)
+	for _, hash := range hashMap {
+		hashes = append(hashes, hash)
+	}
+	getTxReq := &eth.GetPooledTransactionsPacket66{
+		RequestId:                   1234,
+		GetPooledTransactionsPacket: hashes,
+	}
+	if err := recvConn.write66(getTxReq, GetPooledTransactions{}.Code()); err != nil {
+		t.Fatalf("could not write to conn: %v", err)
+	}
+	// check that all received transactions match those that were sent to node
+	switch msg := recvConn.waitForResponse(s.chain, timeout, getTxReq.RequestId).(type) {
+	case PooledTransactions:
+		for _, gotTx := range msg {
+			if _, exists := hashMap[gotTx.Hash()]; !exists {
+				t.Fatalf("unexpected tx received: %v", gotTx.Hash())
+			}
+		}
+	default:
+		t.Fatalf("unexpected %s", pretty.Sdump(msg))
+	}
+}
+
+// TestNewPooledTxs_66 tests whether a node will do a GetPooledTransactions
+// request upon receiving a NewPooledTransactionHashes announcement.
+func (s *Suite) TestNewPooledTxs_66(t *utesting.T) {
+	// send the next block to ensure the node is no longer syncing and is able to accept
+	// txs
+	s.sendNextBlock66(t)
+	// generate 50 txs
+	hashMap, _ := generateTxs(t, s, 50)
+	// create new pooled tx hashes announcement
+	hashes := make([]common.Hash, 0)
+	for _, hash := range hashMap {
+		hashes = append(hashes, hash)
+	}
+	announce := NewPooledTransactionHashes(hashes)
+	// send announcement
+	conn := s.setupConnection66(t)
+	defer conn.Close()
+	if err := conn.Write(announce); err != nil {
+		t.Fatalf("could not write to connection: %v", err)
+	}
+	// wait for GetPooledTxs request
+	for {
+		_, msg := conn.readAndServe66(s.chain, timeout)
+		switch msg := msg.(type) {
+		case GetPooledTransactions:
+			if len(msg) != len(hashes) {
+				t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg))
+			}
+			return
+		case *NewPooledTransactionHashes:
+			// ignore propagated txs from old tests
+			continue
+		default:
+			t.Fatalf("unexpected %s", pretty.Sdump(msg))
+		}
+	}
+}

+ 69 - 21
cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go

@@ -111,6 +111,18 @@ func (c *Conn) read66() (uint64, Message) {
 		msg = new(Transactions)
 	case (NewPooledTransactionHashes{}).Code():
 		msg = new(NewPooledTransactionHashes)
+	case (GetPooledTransactions{}.Code()):
+		ethMsg := new(eth.GetPooledTransactionsPacket66)
+		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
+			return 0, errorf("could not rlp decode message: %v", err)
+		}
+		return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket)
+	case (PooledTransactions{}.Code()):
+		ethMsg := new(eth.PooledTransactionsPacket66)
+		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
+			return 0, errorf("could not rlp decode message: %v", err)
+		}
+		return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket)
 	default:
 		msg = errorf("invalid message code: %d", code)
 	}
@@ -124,6 +136,15 @@ func (c *Conn) read66() (uint64, Message) {
 	return 0, errorf("invalid message: %s", string(rawData))
 }
 
+func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message {
+	for {
+		id, msg := c.readAndServe66(chain, timeout)
+		if id == requestID {
+			return msg
+		}
+	}
+}
+
 // ReadAndServe serves GetBlockHeaders requests while waiting
 // on another message from the node.
 func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
@@ -173,27 +194,33 @@ func (s *Suite) testAnnounce66(t *utesting.T, sendConn, receiveConn *Conn, block
 }
 
 func (s *Suite) waitAnnounce66(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) {
-	timeout := 20 * time.Second
-	_, msg := conn.readAndServe66(s.chain, timeout)
-	switch msg := msg.(type) {
-	case *NewBlock:
-		t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block))
-		assert.Equal(t,
-			blockAnnouncement.Block.Header(), msg.Block.Header(),
-			"wrong block header in announcement",
-		)
-		assert.Equal(t,
-			blockAnnouncement.TD, msg.TD,
-			"wrong TD in announcement",
-		)
-	case *NewBlockHashes:
-		blockHashes := *msg
-		t.Logf("received NewBlockHashes message: %s", pretty.Sdump(blockHashes))
-		assert.Equal(t, blockAnnouncement.Block.Hash(), blockHashes[0].Hash,
-			"wrong block hash in announcement",
-		)
-	default:
-		t.Fatalf("unexpected: %s", pretty.Sdump(msg))
+	for {
+		_, msg := conn.readAndServe66(s.chain, timeout)
+		switch msg := msg.(type) {
+		case *NewBlock:
+			t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block))
+			assert.Equal(t,
+				blockAnnouncement.Block.Header(), msg.Block.Header(),
+				"wrong block header in announcement",
+			)
+			assert.Equal(t,
+				blockAnnouncement.TD, msg.TD,
+				"wrong TD in announcement",
+			)
+			return
+		case *NewBlockHashes:
+			blockHashes := *msg
+			t.Logf("received NewBlockHashes message: %s", pretty.Sdump(blockHashes))
+			assert.Equal(t, blockAnnouncement.Block.Hash(), blockHashes[0].Hash,
+				"wrong block hash in announcement",
+			)
+			return
+		case *NewPooledTransactionHashes:
+			// ignore old txs being propagated
+			continue
+		default:
+			t.Fatalf("unexpected: %s", pretty.Sdump(msg))
+		}
 	}
 }
 
@@ -268,3 +295,24 @@ func headersMatch(t *utesting.T, chain *Chain, headers BlockHeaders) {
 		assert.Equal(t, chain.blocks[int(num)].Header(), header)
 	}
 }
+
+func (s *Suite) sendNextBlock66(t *utesting.T) {
+	sendConn, receiveConn := s.setupConnection66(t), s.setupConnection66(t)
+	defer sendConn.Close()
+	defer receiveConn.Close()
+
+	// create new block announcement
+	nextBlock := len(s.chain.blocks)
+	blockAnnouncement := &NewBlock{
+		Block: s.fullChain.blocks[nextBlock],
+		TD:    s.fullChain.TD(nextBlock + 1),
+	}
+	// send announcement and wait for node to request the header
+	s.testAnnounce66(t, sendConn, receiveConn, blockAnnouncement)
+	// update test suite chain
+	s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock])
+	// wait for client to update its chain
+	if err := receiveConn.waitForBlock66(s.chain.Head()); err != nil {
+		t.Fatal(err)
+	}
+}

+ 4 - 1
cmd/devp2p/internal/ethtest/suite.go

@@ -97,6 +97,8 @@ func (s *Suite) AllEthTests() []utesting.Test {
 		{Name: "TestTransaction_66", Fn: s.TestTransaction_66},
 		{Name: "TestMaliciousTx", Fn: s.TestMaliciousTx},
 		{Name: "TestMaliciousTx_66", Fn: s.TestMaliciousTx_66},
+		{Name: "TestLargeTxRequest_66", Fn: s.TestLargeTxRequest_66},
+		{Name: "TestNewPooledTxs_66", Fn: s.TestNewPooledTxs_66},
 	}
 }
 
@@ -129,6 +131,8 @@ func (s *Suite) Eth66Tests() []utesting.Test {
 		{Name: "TestMaliciousStatus_66", Fn: s.TestMaliciousStatus_66},
 		{Name: "TestTransaction_66", Fn: s.TestTransaction_66},
 		{Name: "TestMaliciousTx_66", Fn: s.TestMaliciousTx_66},
+		{Name: "TestLargeTxRequest_66", Fn: s.TestLargeTxRequest_66},
+		{Name: "TestNewPooledTxs_66", Fn: s.TestNewPooledTxs_66},
 	}
 }
 
@@ -455,7 +459,6 @@ func (s *Suite) testAnnounce(t *utesting.T, sendConn, receiveConn *Conn, blockAn
 }
 
 func (s *Suite) waitAnnounce(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) {
-	timeout := 20 * time.Second
 	switch msg := conn.ReadAndServe(s.chain, timeout).(type) {
 	case *NewBlock:
 		t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block))

+ 112 - 20
cmd/devp2p/internal/ethtest/transaction.go

@@ -17,12 +17,15 @@
 package ethtest
 
 import (
+	"math/big"
+	"strings"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/internal/utesting"
+	"github.com/ethereum/go-ethereum/params"
 )
 
 //var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7")
@@ -40,7 +43,9 @@ func sendSuccessfulTxWithConn(t *utesting.T, s *Suite, tx *types.Transaction, se
 	if err := sendConn.Write(&Transactions{tx}); err != nil {
 		t.Fatal(err)
 	}
-	time.Sleep(100 * time.Millisecond)
+	// update last nonce seen
+	nonce = tx.Nonce()
+
 	recvConn := s.setupConnection(t)
 	// Wait for the transaction announcement
 	switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) {
@@ -66,6 +71,60 @@ func sendSuccessfulTxWithConn(t *utesting.T, s *Suite, tx *types.Transaction, se
 	}
 }
 
+var nonce = uint64(99)
+
+func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*types.Transaction) {
+	txMsg := Transactions(txs)
+	t.Logf("sending %d txs\n", len(txs))
+
+	recvConn := s.setupConnection(t)
+	defer recvConn.Close()
+
+	// Send the transactions
+	if err := sendConn.Write(&txMsg); err != nil {
+		t.Fatal(err)
+	}
+	// update nonce
+	nonce = txs[len(txs)-1].Nonce()
+	// Wait for the transaction announcement(s) and make sure all sent txs are being propagated
+	recvHashes := make([]common.Hash, 0)
+	// all txs should be announced within 3 announcements
+	for i := 0; i < 3; i++ {
+		switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) {
+		case *Transactions:
+			for _, tx := range *msg {
+				recvHashes = append(recvHashes, tx.Hash())
+			}
+		case *NewPooledTransactionHashes:
+			recvHashes = append(recvHashes, *msg...)
+		default:
+			if !strings.Contains(pretty.Sdump(msg), "i/o timeout") {
+				t.Fatalf("unexpected message while waiting to receive txs: %s", pretty.Sdump(msg))
+			}
+		}
+		// break once all 2000 txs have been received
+		if len(recvHashes) == 2000 {
+			break
+		}
+		if len(recvHashes) > 0 {
+			_, missingTxs := compareReceivedTxs(recvHashes, txs)
+			if len(missingTxs) > 0 {
+				continue
+			} else {
+				t.Logf("successfully received all %d txs", len(txs))
+				return
+			}
+		}
+	}
+	_, missingTxs := compareReceivedTxs(recvHashes, txs)
+	if len(missingTxs) > 0 {
+		for _, missing := range missingTxs {
+			t.Logf("missing tx: %v", missing.Hash())
+		}
+		t.Fatalf("missing %d txs", len(missingTxs))
+	}
+}
+
 func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, recvConn *Conn) {
 	// Wait for another transaction announcement
 	switch msg := recvConn.ReadAndServe(s.chain, time.Second*8).(type) {
@@ -75,7 +134,7 @@ func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, rec
 		for i, recvTx := range *msg {
 			recvTxs[i] = recvTx.Hash()
 		}
-		badTxs := containsTxs(recvTxs, txs)
+		badTxs, _ := compareReceivedTxs(recvTxs, txs)
 		if len(badTxs) > 0 {
 			for _, tx := range badTxs {
 				t.Logf("received bad tx: %v", tx)
@@ -83,7 +142,7 @@ func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, rec
 			t.Fatalf("received %d bad txs", len(badTxs))
 		}
 	case *NewPooledTransactionHashes:
-		badTxs := containsTxs(*msg, txs)
+		badTxs, _ := compareReceivedTxs(*msg, txs)
 		if len(badTxs) > 0 {
 			for _, tx := range badTxs {
 				t.Logf("received bad tx: %v", tx)
@@ -98,18 +157,27 @@ func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, rec
 	}
 }
 
-// containsTxs checks whether the hashes of the received transactions are present in
-// the given set of txs
-func containsTxs(recvTxs []common.Hash, txs []*types.Transaction) []common.Hash {
-	containedTxs := make([]common.Hash, 0)
-	for _, recvTx := range recvTxs {
-		for _, tx := range txs {
-			if recvTx == tx.Hash() {
-				containedTxs = append(containedTxs, recvTx)
-			}
+// compareReceivedTxs compares the received set of txs against the given set of txs,
+// returning both the set received txs that were present within the given txs, and
+// the set of txs that were missing from the set of received txs
+func compareReceivedTxs(recvTxs []common.Hash, txs []*types.Transaction) (present []*types.Transaction, missing []*types.Transaction) {
+	// create a map of the hashes received from node
+	recvHashes := make(map[common.Hash]common.Hash)
+	for _, hash := range recvTxs {
+		recvHashes[hash] = hash
+	}
+
+	// collect present txs and missing txs separately
+	present = make([]*types.Transaction, 0)
+	missing = make([]*types.Transaction, 0)
+	for _, tx := range txs {
+		if _, exists := recvHashes[tx.Hash()]; exists {
+			present = append(present, tx)
+		} else {
+			missing = append(missing, tx)
 		}
 	}
-	return containedTxs
+	return present, missing
 }
 
 func unknownTx(t *utesting.T, s *Suite) *types.Transaction {
@@ -119,7 +187,7 @@ func unknownTx(t *utesting.T, s *Suite) *types.Transaction {
 		to = *tx.To()
 	}
 	txNew := types.NewTransaction(tx.Nonce()+1, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data())
-	return signWithFaucet(t, txNew)
+	return signWithFaucet(t, s.chain.chainConfig, txNew)
 }
 
 func getNextTxFromChain(t *utesting.T, s *Suite) *types.Transaction {
@@ -138,6 +206,30 @@ func getNextTxFromChain(t *utesting.T, s *Suite) *types.Transaction {
 	return tx
 }
 
+func generateTxs(t *utesting.T, s *Suite, numTxs int) (map[common.Hash]common.Hash, []*types.Transaction) {
+	txHashMap := make(map[common.Hash]common.Hash, numTxs)
+	txs := make([]*types.Transaction, numTxs)
+
+	nextTx := getNextTxFromChain(t, s)
+	gas := nextTx.Gas()
+
+	nonce = nonce + 1
+	// generate txs
+	for i := 0; i < numTxs; i++ {
+		tx := generateTx(t, s.chain.chainConfig, nonce, gas)
+		txHashMap[tx.Hash()] = tx.Hash()
+		txs[i] = tx
+		nonce = nonce + 1
+	}
+	return txHashMap, txs
+}
+
+func generateTx(t *utesting.T, chainConfig *params.ChainConfig, nonce uint64, gas uint64) *types.Transaction {
+	var to common.Address
+	tx := types.NewTransaction(nonce, to, big.NewInt(1), gas, big.NewInt(1), []byte{})
+	return signWithFaucet(t, chainConfig, tx)
+}
+
 func getOldTxFromChain(t *utesting.T, s *Suite) *types.Transaction {
 	var tx *types.Transaction
 	for _, blocks := range s.fullChain.blocks[:s.chain.Len()-1] {
@@ -160,7 +252,7 @@ func invalidNonceTx(t *utesting.T, s *Suite) *types.Transaction {
 		to = *tx.To()
 	}
 	txNew := types.NewTransaction(tx.Nonce()-2, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data())
-	return signWithFaucet(t, txNew)
+	return signWithFaucet(t, s.chain.chainConfig, txNew)
 }
 
 func hugeAmount(t *utesting.T, s *Suite) *types.Transaction {
@@ -171,7 +263,7 @@ func hugeAmount(t *utesting.T, s *Suite) *types.Transaction {
 		to = *tx.To()
 	}
 	txNew := types.NewTransaction(tx.Nonce(), to, amount, tx.Gas(), tx.GasPrice(), tx.Data())
-	return signWithFaucet(t, txNew)
+	return signWithFaucet(t, s.chain.chainConfig, txNew)
 }
 
 func hugeGasPrice(t *utesting.T, s *Suite) *types.Transaction {
@@ -182,7 +274,7 @@ func hugeGasPrice(t *utesting.T, s *Suite) *types.Transaction {
 		to = *tx.To()
 	}
 	txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), gasPrice, tx.Data())
-	return signWithFaucet(t, txNew)
+	return signWithFaucet(t, s.chain.chainConfig, txNew)
 }
 
 func hugeData(t *utesting.T, s *Suite) *types.Transaction {
@@ -192,11 +284,11 @@ func hugeData(t *utesting.T, s *Suite) *types.Transaction {
 		to = *tx.To()
 	}
 	txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), tx.GasPrice(), largeBuffer(2))
-	return signWithFaucet(t, txNew)
+	return signWithFaucet(t, s.chain.chainConfig, txNew)
 }
 
-func signWithFaucet(t *utesting.T, tx *types.Transaction) *types.Transaction {
-	signer := types.HomesteadSigner{}
+func signWithFaucet(t *utesting.T, chainConfig *params.ChainConfig, tx *types.Transaction) *types.Transaction {
+	signer := types.LatestSigner(chainConfig)
 	signedTx, err := types.SignTx(tx, signer, faucetKey)
 	if err != nil {
 		t.Fatalf("could not sign tx: %v\n", err)

+ 12 - 0
cmd/devp2p/internal/ethtest/types.go

@@ -120,6 +120,14 @@ type NewPooledTransactionHashes eth.NewPooledTransactionHashesPacket
 
 func (nb NewPooledTransactionHashes) Code() int { return 24 }
 
+type GetPooledTransactions eth.GetPooledTransactionsPacket
+
+func (gpt GetPooledTransactions) Code() int { return 25 }
+
+type PooledTransactions eth.PooledTransactionsPacket
+
+func (pt PooledTransactions) Code() int { return 26 }
+
 // Conn represents an individual connection with a peer
 type Conn struct {
 	*rlpx.Conn
@@ -163,6 +171,10 @@ func (c *Conn) Read() Message {
 		msg = new(Transactions)
 	case (NewPooledTransactionHashes{}).Code():
 		msg = new(NewPooledTransactionHashes)
+	case (GetPooledTransactions{}.Code()):
+		msg = new(GetPooledTransactions)
+	case (PooledTransactions{}.Code()):
+		msg = new(PooledTransactions)
 	default:
 		return errorf("invalid message code: %d", code)
 	}