Kaynağa Gözat

cmd/devp2p: add eth protocol test suite (#21598)

This change adds a test framework for the "eth" protocol and some basic
tests. The tests can be run using the './devp2p rlpx eth-test' command.
rene 5 yıl önce
ebeveyn
işleme
a25899f3dc

+ 113 - 0
cmd/devp2p/internal/ethtest/chain.go

@@ -0,0 +1,113 @@
+package ethtest
+
+import (
+	"compress/gzip"
+	"encoding/json"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"math/big"
+	"os"
+	"strings"
+
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/forkid"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+type Chain struct {
+	blocks      []*types.Block
+	chainConfig *params.ChainConfig
+}
+
+func (c *Chain) WriteTo(writer io.Writer) error {
+	for _, block := range c.blocks {
+		if err := rlp.Encode(writer, block); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// Len returns the length of the chain.
+func (c *Chain) Len() int {
+	return len(c.blocks)
+}
+
+// TD calculates the total difficulty of the chain.
+func (c *Chain) TD(height int) *big.Int { // TODO later on channge scheme so that the height is included in range
+	sum := big.NewInt(0)
+	for _, block := range c.blocks[:height] {
+		sum.Add(sum, block.Difficulty())
+	}
+	return sum
+}
+
+// ForkID gets the fork id of the chain.
+func (c *Chain) ForkID() forkid.ID {
+	return forkid.NewID(c.chainConfig, c.blocks[0].Hash(), uint64(c.Len()))
+}
+
+// Shorten returns a copy chain of a desired height from the imported
+func (c *Chain) Shorten(height int) *Chain {
+	blocks := make([]*types.Block, height)
+	copy(blocks, c.blocks[:height])
+
+	config := *c.chainConfig
+	return &Chain{
+		blocks:      blocks,
+		chainConfig: &config,
+	}
+}
+
+// Head returns the chain head.
+func (c *Chain) Head() *types.Block {
+	return c.blocks[c.Len()-1]
+}
+
+// loadChain takes the given chain.rlp file, and decodes and returns
+// the blocks from the file.
+func loadChain(chainfile string, genesis string) (*Chain, error) {
+	// Open the file handle and potentially unwrap the gzip stream
+	fh, err := os.Open(chainfile)
+	if err != nil {
+		return nil, err
+	}
+	defer fh.Close()
+
+	var reader io.Reader = fh
+	if strings.HasSuffix(chainfile, ".gz") {
+		if reader, err = gzip.NewReader(reader); err != nil {
+			return nil, err
+		}
+	}
+	stream := rlp.NewStream(reader, 0)
+	var blocks []*types.Block
+	for i := 0; ; i++ {
+		var b types.Block
+		if err := stream.Decode(&b); err == io.EOF {
+			break
+		} else if err != nil {
+			return nil, fmt.Errorf("at block %d: %v", i, err)
+		}
+		blocks = append(blocks, &b)
+	}
+
+	// Open the file handle and potentially unwrap the gzip stream
+	chainConfig, err := ioutil.ReadFile(genesis)
+	if err != nil {
+		return nil, err
+	}
+	var gen core.Genesis
+	if err := json.Unmarshal(chainConfig, &gen); err != nil {
+		return nil, err
+	}
+
+	return &Chain{
+		blocks:      blocks,
+		chainConfig: gen.Config,
+	}, nil
+}

+ 337 - 0
cmd/devp2p/internal/ethtest/suite.go

@@ -0,0 +1,337 @@
+package ethtest
+
+import (
+	"crypto/ecdsa"
+	"fmt"
+	"net"
+	"reflect"
+	"time"
+
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/internal/utesting"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/rlpx"
+	"github.com/ethereum/go-ethereum/rlp"
+	"github.com/stretchr/testify/assert"
+)
+
+// Suite represents a structure used to test the eth
+// protocol of a node(s).
+type Suite struct {
+	Dest *enode.Node
+
+	chain     *Chain
+	fullChain *Chain
+}
+
+type Conn struct {
+	*rlpx.Conn
+	ourKey *ecdsa.PrivateKey
+}
+
+func (c *Conn) Read() Message {
+	code, rawData, _, err := c.Conn.Read()
+	if err != nil {
+		return &Error{fmt.Errorf("could not read from connection: %v", err)}
+	}
+
+	var msg Message
+	switch int(code) {
+	case (Hello{}).Code():
+		msg = new(Hello)
+	case (Disconnect{}).Code():
+		msg = new(Disconnect)
+	case (Status{}).Code():
+		msg = new(Status)
+	case (GetBlockHeaders{}).Code():
+		msg = new(GetBlockHeaders)
+	case (BlockHeaders{}).Code():
+		msg = new(BlockHeaders)
+	case (GetBlockBodies{}).Code():
+		msg = new(GetBlockBodies)
+	case (BlockBodies{}).Code():
+		msg = new(BlockBodies)
+	case (NewBlock{}).Code():
+		msg = new(NewBlock)
+	case (NewBlockHashes{}).Code():
+		msg = new(NewBlockHashes)
+	default:
+		return &Error{fmt.Errorf("invalid message code: %d", code)}
+	}
+
+	if err := rlp.DecodeBytes(rawData, msg); err != nil {
+		return &Error{fmt.Errorf("could not rlp decode message: %v", err)}
+	}
+
+	return msg
+}
+
+func (c *Conn) Write(msg Message) error {
+	payload, err := rlp.EncodeToBytes(msg)
+	if err != nil {
+		return err
+	}
+	_, err = c.Conn.Write(uint64(msg.Code()), payload)
+	return err
+
+}
+
+// handshake checks to make sure a `HELLO` is received.
+func (c *Conn) handshake(t *utesting.T) Message {
+	// write protoHandshake to client
+	pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
+	ourHandshake := &Hello{
+		Version: 3,
+		Caps:    []p2p.Cap{{Name: "eth", Version: 64}, {Name: "eth", Version: 65}},
+		ID:      pub0,
+	}
+	if err := c.Write(ourHandshake); err != nil {
+		t.Fatalf("could not write to connection: %v", err)
+	}
+	// read protoHandshake from client
+	switch msg := c.Read().(type) {
+	case *Hello:
+		return msg
+	default:
+		t.Fatalf("bad handshake: %v", msg)
+		return nil
+	}
+}
+
+// statusExchange performs a `Status` message exchange with the given
+// node.
+func (c *Conn) statusExchange(t *utesting.T, chain *Chain) Message {
+	// read status message from client
+	var message Message
+	switch msg := c.Read().(type) {
+	case *Status:
+		if msg.Head != chain.blocks[chain.Len()-1].Hash() {
+			t.Fatalf("wrong head in status: %v", msg.Head)
+		}
+		if msg.TD.Cmp(chain.TD(chain.Len())) != 0 {
+			t.Fatalf("wrong TD in status: %v", msg.TD)
+		}
+		if !reflect.DeepEqual(msg.ForkID, chain.ForkID()) {
+			t.Fatalf("wrong fork ID in status: %v", msg.ForkID)
+		}
+		message = msg
+	default:
+		t.Fatalf("bad status message: %v", msg)
+	}
+	// write status message to client
+	status := Status{
+		ProtocolVersion: 65,
+		NetworkID:       1,
+		TD:              chain.TD(chain.Len()),
+		Head:            chain.blocks[chain.Len()-1].Hash(),
+		Genesis:         chain.blocks[0].Hash(),
+		ForkID:          chain.ForkID(),
+	}
+	if err := c.Write(status); err != nil {
+		t.Fatalf("could not write to connection: %v", err)
+	}
+
+	return message
+}
+
+// waitForBlock waits for confirmation from the client that it has
+// imported the given block.
+func (c *Conn) waitForBlock(block *types.Block) error {
+	for {
+		req := &GetBlockHeaders{Origin: hashOrNumber{Hash: block.Hash()}, Amount: 1}
+		if err := c.Write(req); err != nil {
+			return err
+		}
+
+		switch msg := c.Read().(type) {
+		case *BlockHeaders:
+			if len(*msg) > 0 {
+				return nil
+			}
+			time.Sleep(100 * time.Millisecond)
+		default:
+			return fmt.Errorf("invalid message: %v", msg)
+		}
+	}
+}
+
+// NewSuite creates and returns a new eth-test suite that can
+// be used to test the given node against the given blockchain
+// data.
+func NewSuite(dest *enode.Node, chainfile string, genesisfile string) *Suite {
+	chain, err := loadChain(chainfile, genesisfile)
+	if err != nil {
+		panic(err)
+	}
+	return &Suite{
+		Dest:      dest,
+		chain:     chain.Shorten(1000),
+		fullChain: chain,
+	}
+}
+
+func (s *Suite) AllTests() []utesting.Test {
+	return []utesting.Test{
+		{Name: "Status", Fn: s.TestStatus},
+		{Name: "GetBlockHeaders", Fn: s.TestGetBlockHeaders},
+		{Name: "Broadcast", Fn: s.TestBroadcast},
+		{Name: "GetBlockBodies", Fn: s.TestGetBlockBodies},
+	}
+}
+
+// TestStatus attempts to connect to the given node and exchange
+// a status message with it, and then check to make sure
+// the chain head is correct.
+func (s *Suite) TestStatus(t *utesting.T) {
+	conn, err := s.dial()
+	if err != nil {
+		t.Fatalf("could not dial: %v", err)
+	}
+	// get protoHandshake
+	conn.handshake(t)
+	// get status
+	switch msg := conn.statusExchange(t, s.chain).(type) {
+	case *Status:
+		t.Logf("%+v\n", msg)
+	default:
+		t.Fatalf("error: %v", msg)
+	}
+}
+
+// TestGetBlockHeaders tests whether the given node can respond to
+// a `GetBlockHeaders` request and that the response is accurate.
+func (s *Suite) TestGetBlockHeaders(t *utesting.T) {
+	conn, err := s.dial()
+	if err != nil {
+		t.Fatalf("could not dial: %v", err)
+	}
+
+	conn.handshake(t)
+	conn.statusExchange(t, s.chain)
+
+	// get block headers
+	req := &GetBlockHeaders{
+		Origin: hashOrNumber{
+			Hash: s.chain.blocks[1].Hash(),
+		},
+		Amount:  2,
+		Skip:    1,
+		Reverse: false,
+	}
+
+	if err := conn.Write(req); err != nil {
+		t.Fatalf("could not write to connection: %v", err)
+	}
+
+	switch msg := conn.Read().(type) {
+	case *BlockHeaders:
+		headers := msg
+		for _, header := range *headers {
+			num := header.Number.Uint64()
+			assert.Equal(t, s.chain.blocks[int(num)].Header(), header)
+			t.Logf("\nHEADER FOR BLOCK NUMBER %d: %+v\n", header.Number, header)
+		}
+	default:
+		t.Fatalf("error: %v", msg)
+	}
+}
+
+// TestGetBlockBodies tests whether the given node can respond to
+// a `GetBlockBodies` request and that the response is accurate.
+func (s *Suite) TestGetBlockBodies(t *utesting.T) {
+	conn, err := s.dial()
+	if err != nil {
+		t.Fatalf("could not dial: %v", err)
+	}
+
+	conn.handshake(t)
+	conn.statusExchange(t, s.chain)
+	// create block bodies request
+	req := &GetBlockBodies{s.chain.blocks[54].Hash(), s.chain.blocks[75].Hash()}
+	if err := conn.Write(req); err != nil {
+		t.Fatalf("could not write to connection: %v", err)
+	}
+
+	switch msg := conn.Read().(type) {
+	case *BlockBodies:
+		bodies := msg
+		for _, body := range *bodies {
+			t.Logf("\nBODY: %+v\n", body)
+		}
+	default:
+		t.Fatalf("error: %v", msg)
+	}
+}
+
+// TestBroadcast tests whether a block announcement is correctly
+// propagated to the given node's peer(s).
+func (s *Suite) TestBroadcast(t *utesting.T) {
+	// create conn to send block announcement
+	sendConn, err := s.dial()
+	if err != nil {
+		t.Fatalf("could not dial: %v", err)
+	}
+	// create conn to receive block announcement
+	receiveConn, err := s.dial()
+	if err != nil {
+		t.Fatalf("could not dial: %v", err)
+	}
+
+	sendConn.handshake(t)
+	receiveConn.handshake(t)
+
+	sendConn.statusExchange(t, s.chain)
+	receiveConn.statusExchange(t, s.chain)
+
+	// sendConn sends the block announcement
+	blockAnnouncement := &NewBlock{
+		Block: s.fullChain.blocks[1000],
+		TD:    s.fullChain.TD(1001),
+	}
+	if err := sendConn.Write(blockAnnouncement); err != nil {
+		t.Fatalf("could not write to connection: %v", err)
+	}
+
+	switch msg := receiveConn.Read().(type) {
+	case *NewBlock:
+		assert.Equal(t, blockAnnouncement.Block.Header(), msg.Block.Header(),
+			"wrong block header in announcement")
+		assert.Equal(t, blockAnnouncement.TD, msg.TD,
+			"wrong TD in announcement")
+	case *NewBlockHashes:
+		hashes := *msg
+		assert.Equal(t, blockAnnouncement.Block.Hash(), hashes[0].Hash,
+			"wrong block hash in announcement")
+	default:
+		t.Fatal(msg)
+	}
+	// update test suite chain
+	s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[1000])
+	// wait for client to update its chain
+	if err := receiveConn.waitForBlock(s.chain.Head()); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// dial attempts to dial the given node and perform a handshake,
+// returning the created Conn if successful.
+func (s *Suite) dial() (*Conn, error) {
+	var conn Conn
+
+	fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
+	if err != nil {
+		return nil, err
+	}
+	conn.Conn = rlpx.NewConn(fd, s.Dest.Pubkey())
+
+	// do encHandshake
+	conn.ourKey, _ = crypto.GenerateKey()
+	_, err = conn.Handshake(conn.ourKey)
+	if err != nil {
+		return nil, err
+	}
+
+	return &conn, nil
+}

+ 134 - 0
cmd/devp2p/internal/ethtest/types.go

@@ -0,0 +1,134 @@
+package ethtest
+
+import (
+	"fmt"
+	"io"
+	"math/big"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/forkid"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+type Message interface {
+	Code() int
+}
+
+type Error struct {
+	err error
+}
+
+func (e *Error) Unwrap() error { return e.err }
+func (e *Error) Error() string { return e.err.Error() }
+func (e *Error) Code() int     { return -1 }
+
+// Hello is the RLP structure of the protocol handshake.
+type Hello struct {
+	Version    uint64
+	Name       string
+	Caps       []p2p.Cap
+	ListenPort uint64
+	ID         []byte // secp256k1 public key
+
+	// Ignore additional fields (for forward compatibility).
+	Rest []rlp.RawValue `rlp:"tail"`
+}
+
+func (h Hello) Code() int { return 0x00 }
+
+// Disconnect is the RLP structure for a disconnect message.
+type Disconnect struct {
+	Reason p2p.DiscReason
+}
+
+func (d Disconnect) Code() int { return 0x01 }
+
+// Status is the network packet for the status message for eth/64 and later.
+type Status struct {
+	ProtocolVersion uint32
+	NetworkID       uint64
+	TD              *big.Int
+	Head            common.Hash
+	Genesis         common.Hash
+	ForkID          forkid.ID
+}
+
+func (s Status) Code() int { return 16 }
+
+// NewBlockHashes is the network packet for the block announcements.
+type NewBlockHashes []struct {
+	Hash   common.Hash // Hash of one particular block being announced
+	Number uint64      // Number of one particular block being announced
+}
+
+func (nbh NewBlockHashes) Code() int { return 17 }
+
+// NewBlock is the network packet for the block propagation message.
+type NewBlock struct {
+	Block *types.Block
+	TD    *big.Int
+}
+
+func (nb NewBlock) Code() int { return 23 }
+
+// GetBlockHeaders represents a block header query.
+type GetBlockHeaders struct {
+	Origin  hashOrNumber // Block from which to retrieve headers
+	Amount  uint64       // Maximum number of headers to retrieve
+	Skip    uint64       // Blocks to skip between consecutive headers
+	Reverse bool         // Query direction (false = rising towards latest, true = falling towards genesis)
+}
+
+func (g GetBlockHeaders) Code() int { return 19 }
+
+type BlockHeaders []*types.Header
+
+func (bh BlockHeaders) Code() int { return 20 }
+
+// HashOrNumber is a combined field for specifying an origin block.
+type hashOrNumber struct {
+	Hash   common.Hash // Block hash from which to retrieve headers (excludes Number)
+	Number uint64      // Block hash from which to retrieve headers (excludes Hash)
+}
+
+// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the
+// two contained union fields.
+func (hn *hashOrNumber) EncodeRLP(w io.Writer) error {
+	if hn.Hash == (common.Hash{}) {
+		return rlp.Encode(w, hn.Number)
+	}
+	if hn.Number != 0 {
+		return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number)
+	}
+	return rlp.Encode(w, hn.Hash)
+}
+
+// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents
+// into either a block hash or a block number.
+func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error {
+	_, size, _ := s.Kind()
+	origin, err := s.Raw()
+	if err == nil {
+		switch {
+		case size == 32:
+			err = rlp.DecodeBytes(origin, &hn.Hash)
+		case size <= 8:
+			err = rlp.DecodeBytes(origin, &hn.Number)
+		default:
+			err = fmt.Errorf("invalid input size %d for origin", size)
+		}
+	}
+	return err
+}
+
+// GetBlockBodies represents a GetBlockBodies request
+type GetBlockBodies []common.Hash
+
+func (gbb GetBlockBodies) Code() int { return 21 }
+
+// BlockBodies is the network packet for block content distribution.
+type BlockBodies []*types.Body
+
+func (bb BlockBodies) Code() int { return 22 }

+ 1 - 1
cmd/devp2p/main.go

@@ -81,7 +81,7 @@ func commandHasFlag(ctx *cli.Context, flag cli.Flag) bool {
 
 // getNodeArg handles the common case of a single node descriptor argument.
 func getNodeArg(ctx *cli.Context) *enode.Node {
-	if ctx.NArg() != 1 {
+	if ctx.NArg() < 1 {
 		exit("missing node as command-line argument")
 	}
 	n, err := parseNode(ctx.Args()[0])

+ 33 - 18
cmd/devp2p/rlpxcmd.go

@@ -19,9 +19,11 @@ package main
 import (
 	"fmt"
 	"net"
+	"os"
 
-	"github.com/ethereum/go-ethereum/common/hexutil"
+	"github.com/ethereum/go-ethereum/cmd/devp2p/internal/ethtest"
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/internal/utesting"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/rlpx"
 	"github.com/ethereum/go-ethereum/rlp"
@@ -34,38 +36,42 @@ var (
 		Usage: "RLPx Commands",
 		Subcommands: []cli.Command{
 			rlpxPingCommand,
+			rlpxEthTestCommand,
 		},
 	}
 	rlpxPingCommand = cli.Command{
-		Name:      "ping",
-		Usage:     "Perform a RLPx handshake",
-		ArgsUsage: "<node>",
-		Action:    rlpxPing,
+		Name:   "ping",
+		Usage:  "ping <node>",
+		Action: rlpxPing,
+	}
+	rlpxEthTestCommand = cli.Command{
+		Name:      "eth-test",
+		Usage:     "Runs tests against a node",
+		ArgsUsage: "<node> <path_to_chain.rlp_file>",
+		Action:    rlpxEthTest,
+		Flags:     []cli.Flag{testPatternFlag},
 	}
 )
 
 func rlpxPing(ctx *cli.Context) error {
 	n := getNodeArg(ctx)
-
 	fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", n.IP(), n.TCP()))
 	if err != nil {
 		return err
 	}
 	conn := rlpx.NewConn(fd, n.Pubkey())
-
 	ourKey, _ := crypto.GenerateKey()
 	_, err = conn.Handshake(ourKey)
 	if err != nil {
 		return err
 	}
-
 	code, data, _, err := conn.Read()
 	if err != nil {
 		return err
 	}
 	switch code {
 	case 0:
-		var h devp2pHandshake
+		var h ethtest.Hello
 		if err := rlp.DecodeBytes(data, &h); err != nil {
 			return fmt.Errorf("invalid handshake: %v", err)
 		}
@@ -82,13 +88,22 @@ func rlpxPing(ctx *cli.Context) error {
 	return nil
 }
 
-// devp2pHandshake is the RLP structure of the devp2p protocol handshake.
-type devp2pHandshake struct {
-	Version    uint64
-	Name       string
-	Caps       []p2p.Cap
-	ListenPort uint64
-	ID         hexutil.Bytes // secp256k1 public key
-	// Ignore additional fields (for forward compatibility).
-	Rest []rlp.RawValue `rlp:"tail"`
+func rlpxEthTest(ctx *cli.Context) error {
+	if ctx.NArg() < 3 {
+		exit("missing path to chain.rlp as command-line argument")
+	}
+
+	suite := ethtest.NewSuite(getNodeArg(ctx), ctx.Args()[1], ctx.Args()[2])
+
+	// Filter and run test cases.
+	tests := suite.AllTests()
+	if ctx.IsSet(testPatternFlag.Name) {
+		tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name))
+	}
+	results := utesting.RunTests(tests, os.Stdout)
+	if fails := utesting.CountFailures(results); fails > 0 {
+		return fmt.Errorf("%v of %v tests passed.", len(tests)-fails, len(tests))
+	}
+	fmt.Printf("all tests passed\n")
+	return nil
 }

+ 2 - 13
core/forkid/forkid.go

@@ -65,19 +65,8 @@ type ID struct {
 // Filter is a fork id filter to validate a remotely advertised ID.
 type Filter func(id ID) error
 
-// NewID calculates the Ethereum fork ID from the chain config and head.
-func NewID(chain Blockchain) ID {
-	return newID(
-		chain.Config(),
-		chain.Genesis().Hash(),
-		chain.CurrentHeader().Number.Uint64(),
-	)
-}
-
-// newID is the internal version of NewID, which takes extracted values as its
-// arguments instead of a chain. The reason is to allow testing the IDs without
-// having to simulate an entire blockchain.
-func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
+// NewID calculates the Ethereum fork ID from the chain config, genesis hash, and head.
+func NewID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
 	// Calculate the starting checksum from the genesis hash
 	hash := crc32.ChecksumIEEE(genesis[:])
 

+ 1 - 1
core/forkid/forkid_test.go

@@ -118,7 +118,7 @@ func TestCreation(t *testing.T) {
 	}
 	for i, tt := range tests {
 		for j, ttt := range tt.cases {
-			if have := newID(tt.config, tt.genesis, ttt.head); have != ttt.want {
+			if have := NewID(tt.config, tt.genesis, ttt.head); have != ttt.want {
 				t.Errorf("test %d, case %d: fork ID mismatch: have %x, want %x", i, j, have, ttt.want)
 			}
 		}

+ 2 - 1
eth/discovery.go

@@ -60,7 +60,8 @@ func (eth *Ethereum) startEthEntryUpdate(ln *enode.LocalNode) {
 }
 
 func (eth *Ethereum) currentEthEntry() *ethEntry {
-	return &ethEntry{ForkID: forkid.NewID(eth.blockchain)}
+	return &ethEntry{ForkID: forkid.NewID(eth.blockchain.Config(), eth.blockchain.Genesis().Hash(),
+		eth.blockchain.CurrentHeader().Number.Uint64())}
 }
 
 // setupDiscovery creates the node discovery source for the eth protocol.

+ 2 - 1
eth/handler.go

@@ -319,7 +319,8 @@ func (pm *ProtocolManager) handle(p *peer) error {
 		number  = head.Number.Uint64()
 		td      = pm.blockchain.GetTd(hash, number)
 	)
-	if err := p.Handshake(pm.networkID, td, hash, genesis.Hash(), forkid.NewID(pm.blockchain), pm.forkFilter); err != nil {
+	forkID := forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
+	if err := p.Handshake(pm.networkID, td, hash, genesis.Hash(), forkID, pm.forkFilter); err != nil {
 		p.Log().Debug("Ethereum handshake failed", "err", err)
 		return err
 	}

+ 2 - 1
eth/helper_test.go

@@ -185,7 +185,8 @@ func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*te
 			head    = pm.blockchain.CurrentHeader()
 			td      = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
 		)
-		tp.handshake(nil, td, head.Hash(), genesis.Hash(), forkid.NewID(pm.blockchain), forkid.NewFilter(pm.blockchain))
+		forkID := forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
+		tp.handshake(nil, td, head.Hash(), genesis.Hash(), forkID, forkid.NewFilter(pm.blockchain))
 	}
 	return tp, errc
 }

+ 1 - 1
eth/protocol_test.go

@@ -104,7 +104,7 @@ func TestStatusMsgErrors64(t *testing.T) {
 		genesis = pm.blockchain.Genesis()
 		head    = pm.blockchain.CurrentHeader()
 		td      = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
-		forkID  = forkid.NewID(pm.blockchain)
+		forkID  = forkid.NewID(pm.blockchain.Config(), pm.blockchain.Genesis().Hash(), pm.blockchain.CurrentHeader().Number.Uint64())
 	)
 	defer pm.Stop()