瀏覽代碼

eth, les: add error when accessing missing block state (#18346)

This change makes getBalance, getCode, getStorageAt, getProof,
call, getTransactionCount return an error if the block number in
the request doesn't exist. getHeaderByNumber still returns null
for missing headers.
Martin Holst Swende 6 年之前
父節點
當前提交
5036992b06
共有 3 個文件被更改,包括 161 次插入2 次删除
  1. 5 1
      eth/api_backend.go
  2. 151 0
      ethclient/ethclient_test.go
  3. 5 1
      les/api_backend.go

+ 5 - 1
eth/api_backend.go

@@ -18,6 +18,7 @@ package eth
 
 import (
 	"context"
+	"errors"
 	"math/big"
 
 	"github.com/ethereum/go-ethereum/accounts"
@@ -95,9 +96,12 @@ func (b *EthAPIBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.
 	}
 	// Otherwise resolve the block number and return its state
 	header, err := b.HeaderByNumber(ctx, blockNr)
-	if header == nil || err != nil {
+	if err != nil {
 		return nil, nil, err
 	}
+	if header == nil {
+		return nil, nil, errors.New("header not found")
+	}
 	stateDb, err := b.eth.BlockChain().StateAt(header.Root)
 	return stateDb, header, err
 }

+ 151 - 0
ethclient/ethclient_test.go

@@ -17,13 +17,24 @@
 package ethclient
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"math/big"
 	"reflect"
 	"testing"
+	"time"
 
 	"github.com/ethereum/go-ethereum"
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/eth"
+	"github.com/ethereum/go-ethereum/node"
+	"github.com/ethereum/go-ethereum/params"
 )
 
 // Verify that Client implements the ethereum interfaces.
@@ -150,3 +161,143 @@ func TestToFilterArg(t *testing.T) {
 		})
 	}
 }
+
+var (
+	testKey, _  = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+	testAddr    = crypto.PubkeyToAddress(testKey.PublicKey)
+	testBalance = big.NewInt(2e10)
+)
+
+func newTestBackend(t *testing.T) (*node.Node, []*types.Block) {
+	// Generate test chain.
+	genesis, blocks := generateTestChain()
+
+	// Start Ethereum service.
+	var ethservice *eth.Ethereum
+	n, err := node.New(&node.Config{})
+	n.Register(func(ctx *node.ServiceContext) (node.Service, error) {
+		config := &eth.Config{Genesis: genesis}
+		config.Ethash.PowMode = ethash.ModeFake
+		ethservice, err = eth.New(ctx, config)
+		return ethservice, err
+	})
+
+	// Import the test chain.
+	if err := n.Start(); err != nil {
+		t.Fatalf("can't start test node: %v", err)
+	}
+	if _, err := ethservice.BlockChain().InsertChain(blocks[1:]); err != nil {
+		t.Fatalf("can't import test blocks: %v", err)
+	}
+	return n, blocks
+}
+
+func generateTestChain() (*core.Genesis, []*types.Block) {
+	db := rawdb.NewMemoryDatabase()
+	config := params.AllEthashProtocolChanges
+	genesis := &core.Genesis{
+		Config:    config,
+		Alloc:     core.GenesisAlloc{testAddr: {Balance: testBalance}},
+		ExtraData: []byte("test genesis"),
+		Timestamp: 9000,
+	}
+	generate := func(i int, g *core.BlockGen) {
+		g.OffsetTime(5)
+		g.SetExtra([]byte("test"))
+	}
+	gblock := genesis.ToBlock(db)
+	engine := ethash.NewFaker()
+	blocks, _ := core.GenerateChain(config, gblock, engine, db, 1, generate)
+	blocks = append([]*types.Block{gblock}, blocks...)
+	return genesis, blocks
+}
+
+func TestHeader(t *testing.T) {
+	backend, chain := newTestBackend(t)
+	client, _ := backend.Attach()
+	defer backend.Stop()
+	defer client.Close()
+
+	tests := map[string]struct {
+		block   *big.Int
+		want    *types.Header
+		wantErr error
+	}{
+		"genesis": {
+			block: big.NewInt(0),
+			want:  chain[0].Header(),
+		},
+		"first_block": {
+			block: big.NewInt(1),
+			want:  chain[1].Header(),
+		},
+		"future_block": {
+			block: big.NewInt(1000000000),
+			want:  nil,
+		},
+	}
+	for name, tt := range tests {
+		t.Run(name, func(t *testing.T) {
+			ec := NewClient(client)
+			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+			defer cancel()
+
+			got, err := ec.HeaderByNumber(ctx, tt.block)
+			if tt.wantErr != nil && (err == nil || err.Error() != tt.wantErr.Error()) {
+				t.Fatalf("HeaderByNumber(%v) error = %q, want %q", tt.block, err, tt.wantErr)
+			}
+			if got != nil && got.Number.Sign() == 0 {
+				got.Number = big.NewInt(0) // hack to make DeepEqual work
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Fatalf("HeaderByNumber(%v)\n   = %v\nwant %v", tt.block, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestBalanceAt(t *testing.T) {
+	backend, _ := newTestBackend(t)
+	client, _ := backend.Attach()
+	defer backend.Stop()
+	defer client.Close()
+
+	tests := map[string]struct {
+		account common.Address
+		block   *big.Int
+		want    *big.Int
+		wantErr error
+	}{
+		"valid_account": {
+			account: testAddr,
+			block:   big.NewInt(1),
+			want:    testBalance,
+		},
+		"non_existent_account": {
+			account: common.Address{1},
+			block:   big.NewInt(1),
+			want:    big.NewInt(0),
+		},
+		"future_block": {
+			account: testAddr,
+			block:   big.NewInt(1000000000),
+			want:    big.NewInt(0),
+			wantErr: errors.New("header not found"),
+		},
+	}
+	for name, tt := range tests {
+		t.Run(name, func(t *testing.T) {
+			ec := NewClient(client)
+			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+			defer cancel()
+
+			got, err := ec.BalanceAt(ctx, tt.account, tt.block)
+			if tt.wantErr != nil && (err == nil || err.Error() != tt.wantErr.Error()) {
+				t.Fatalf("BalanceAt(%x, %v) error = %q, want %q", tt.account, tt.block, err, tt.wantErr)
+			}
+			if got.Cmp(tt.want) != 0 {
+				t.Fatalf("BalanceAt(%x, %v) = %v, want %v", tt.account, tt.block, got, tt.want)
+			}
+		})
+	}
+}

+ 5 - 1
les/api_backend.go

@@ -18,6 +18,7 @@ package les
 
 import (
 	"context"
+	"errors"
 	"math/big"
 
 	"github.com/ethereum/go-ethereum/accounts"
@@ -78,9 +79,12 @@ func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb
 
 func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) {
 	header, err := b.HeaderByNumber(ctx, blockNr)
-	if header == nil || err != nil {
+	if err != nil {
 		return nil, nil, err
 	}
+	if header == nil {
+		return nil, nil, errors.New("header not found")
+	}
 	return light.NewState(ctx, header, b.eth.odr), header, nil
 }