|
|
@@ -17,11 +17,19 @@
|
|
|
package state
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
+ "encoding/binary"
|
|
|
+ "fmt"
|
|
|
+ "math"
|
|
|
"math/big"
|
|
|
+ "math/rand"
|
|
|
+ "reflect"
|
|
|
+ "strings"
|
|
|
"testing"
|
|
|
+ "testing/quick"
|
|
|
|
|
|
"github.com/ethereum/go-ethereum/common"
|
|
|
- "github.com/ethereum/go-ethereum/crypto"
|
|
|
+ "github.com/ethereum/go-ethereum/core/vm"
|
|
|
"github.com/ethereum/go-ethereum/ethdb"
|
|
|
)
|
|
|
|
|
|
@@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
|
|
|
|
|
|
// Update it with some accounts
|
|
|
for i := byte(0); i < 255; i++ {
|
|
|
- obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
|
|
|
- obj.AddBalance(big.NewInt(int64(11 * i)))
|
|
|
- obj.SetNonce(uint64(42 * i))
|
|
|
+ addr := common.BytesToAddress([]byte{i})
|
|
|
+ state.AddBalance(addr, big.NewInt(int64(11*i)))
|
|
|
+ state.SetNonce(addr, uint64(42*i))
|
|
|
if i%2 == 0 {
|
|
|
- obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
|
|
|
+ state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
|
|
|
}
|
|
|
if i%3 == 0 {
|
|
|
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
|
|
|
+ state.SetCode(addr, []byte{i, i, i, i, i})
|
|
|
}
|
|
|
- state.UpdateStateObject(obj)
|
|
|
+ state.IntermediateRoot()
|
|
|
}
|
|
|
// Ensure that no data was leaked into the database
|
|
|
for _, key := range db.Keys() {
|
|
|
@@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
|
|
|
transState, _ := New(common.Hash{}, transDb)
|
|
|
finalState, _ := New(common.Hash{}, finalDb)
|
|
|
|
|
|
- // Update the states with some objects
|
|
|
- for i := byte(0); i < 255; i++ {
|
|
|
- // Create a new state object with some data into the transition database
|
|
|
- obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
|
|
|
- obj.SetBalance(big.NewInt(int64(11 * i)))
|
|
|
- obj.SetNonce(uint64(42 * i))
|
|
|
+ modify := func(state *StateDB, addr common.Address, i, tweak byte) {
|
|
|
+ state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
|
|
|
+ state.SetNonce(addr, uint64(42*i+tweak))
|
|
|
if i%2 == 0 {
|
|
|
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0}))
|
|
|
+ state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
|
|
|
+ state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
|
|
|
}
|
|
|
if i%3 == 0 {
|
|
|
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0})
|
|
|
+ state.SetCode(addr, []byte{i, i, i, i, i, tweak})
|
|
|
}
|
|
|
- transState.UpdateStateObject(obj)
|
|
|
+ }
|
|
|
|
|
|
- // Overwrite all the data with new values in the transition database
|
|
|
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
|
|
|
- obj.SetNonce(uint64(42*i + 1))
|
|
|
- if i%2 == 0 {
|
|
|
- obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{})
|
|
|
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
|
|
|
- }
|
|
|
- if i%3 == 0 {
|
|
|
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
|
|
|
- }
|
|
|
- transState.UpdateStateObject(obj)
|
|
|
+ // Modify the transient state.
|
|
|
+ for i := byte(0); i < 255; i++ {
|
|
|
+ modify(transState, common.Address{byte(i)}, i, 0)
|
|
|
+ }
|
|
|
+ // Write modifications to trie.
|
|
|
+ transState.IntermediateRoot()
|
|
|
|
|
|
- // Create the final state object directly in the final database
|
|
|
- obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
|
|
|
- obj.SetBalance(big.NewInt(int64(11*i + 1)))
|
|
|
- obj.SetNonce(uint64(42*i + 1))
|
|
|
- if i%2 == 0 {
|
|
|
- obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
|
|
|
- }
|
|
|
- if i%3 == 0 {
|
|
|
- obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
|
|
|
- }
|
|
|
- finalState.UpdateStateObject(obj)
|
|
|
+ // Overwrite all the data with new values in the transient database.
|
|
|
+ for i := byte(0); i < 255; i++ {
|
|
|
+ modify(transState, common.Address{byte(i)}, i, 99)
|
|
|
+ modify(finalState, common.Address{byte(i)}, i, 99)
|
|
|
}
|
|
|
+
|
|
|
+ // Commit and cross check the databases.
|
|
|
if _, err := transState.Commit(); err != nil {
|
|
|
t.Fatalf("failed to commit transition state: %v", err)
|
|
|
}
|
|
|
if _, err := finalState.Commit(); err != nil {
|
|
|
t.Fatalf("failed to commit final state: %v", err)
|
|
|
}
|
|
|
- // Cross check the databases to ensure they are the same
|
|
|
for _, key := range finalDb.Keys() {
|
|
|
if _, err := transDb.Get(key); err != nil {
|
|
|
val, _ := finalDb.Get(key)
|
|
|
@@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestSnapshotRandom(t *testing.T) {
|
|
|
+ config := &quick.Config{MaxCount: 1000}
|
|
|
+ err := quick.Check((*snapshotTest).run, config)
|
|
|
+ if cerr, ok := err.(*quick.CheckError); ok {
|
|
|
+ test := cerr.In[0].(*snapshotTest)
|
|
|
+ t.Errorf("%v:\n%s", test.err, test)
|
|
|
+ } else if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
|
|
|
+// captured by the snapshot. Instances of this test with pseudorandom content are created
|
|
|
+// by Generate.
|
|
|
+//
|
|
|
+// The test works as follows:
|
|
|
+//
|
|
|
+// A new state is created and all actions are applied to it. Several snapshots are taken
|
|
|
+// in between actions. The test then reverts each snapshot. For each snapshot the actions
|
|
|
+// leading up to it are replayed on a fresh, empty state. The behaviour of all public
|
|
|
+// accessor methods on the reverted state must match the return value of the equivalent
|
|
|
+// methods on the replayed state.
|
|
|
+type snapshotTest struct {
|
|
|
+ addrs []common.Address // all account addresses
|
|
|
+ actions []testAction // modifications to the state
|
|
|
+ snapshots []int // actions indexes at which snapshot is taken
|
|
|
+ err error // failure details are reported through this field
|
|
|
+}
|
|
|
+
|
|
|
+type testAction struct {
|
|
|
+ name string
|
|
|
+ fn func(testAction, *StateDB)
|
|
|
+ args []int64
|
|
|
+ noAddr bool
|
|
|
+}
|
|
|
+
|
|
|
+// newTestAction creates a random action that changes state.
|
|
|
+func newTestAction(addr common.Address, r *rand.Rand) testAction {
|
|
|
+ actions := []testAction{
|
|
|
+ {
|
|
|
+ name: "SetBalance",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ s.SetBalance(addr, big.NewInt(a.args[0]))
|
|
|
+ },
|
|
|
+ args: make([]int64, 1),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "AddBalance",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ s.AddBalance(addr, big.NewInt(a.args[0]))
|
|
|
+ },
|
|
|
+ args: make([]int64, 1),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "SetNonce",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ s.SetNonce(addr, uint64(a.args[0]))
|
|
|
+ },
|
|
|
+ args: make([]int64, 1),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "SetState",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ var key, val common.Hash
|
|
|
+ binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
|
|
|
+ binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
|
|
|
+ s.SetState(addr, key, val)
|
|
|
+ },
|
|
|
+ args: make([]int64, 2),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "SetCode",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ code := make([]byte, 16)
|
|
|
+ binary.BigEndian.PutUint64(code, uint64(a.args[0]))
|
|
|
+ binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
|
|
|
+ s.SetCode(addr, code)
|
|
|
+ },
|
|
|
+ args: make([]int64, 2),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "CreateAccount",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ s.CreateAccount(addr)
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "Delete",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ s.Delete(addr)
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "AddRefund",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ s.AddRefund(big.NewInt(a.args[0]))
|
|
|
+ },
|
|
|
+ args: make([]int64, 1),
|
|
|
+ noAddr: true,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "AddLog",
|
|
|
+ fn: func(a testAction, s *StateDB) {
|
|
|
+ data := make([]byte, 2)
|
|
|
+ binary.BigEndian.PutUint16(data, uint16(a.args[0]))
|
|
|
+ s.AddLog(&vm.Log{Address: addr, Data: data})
|
|
|
+ },
|
|
|
+ args: make([]int64, 1),
|
|
|
+ },
|
|
|
+ }
|
|
|
+ action := actions[r.Intn(len(actions))]
|
|
|
+ var nameargs []string
|
|
|
+ if !action.noAddr {
|
|
|
+ nameargs = append(nameargs, addr.Hex())
|
|
|
+ }
|
|
|
+ for _, i := range action.args {
|
|
|
+ action.args[i] = rand.Int63n(100)
|
|
|
+ nameargs = append(nameargs, fmt.Sprint(action.args[i]))
|
|
|
+ }
|
|
|
+ action.name += strings.Join(nameargs, ", ")
|
|
|
+ return action
|
|
|
+}
|
|
|
+
|
|
|
+// Generate returns a new snapshot test of the given size. All randomness is
|
|
|
+// derived from r.
|
|
|
+func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
|
|
|
+ // Generate random actions.
|
|
|
+ addrs := make([]common.Address, 50)
|
|
|
+ for i := range addrs {
|
|
|
+ addrs[i][0] = byte(i)
|
|
|
+ }
|
|
|
+ actions := make([]testAction, size)
|
|
|
+ for i := range actions {
|
|
|
+ addr := addrs[r.Intn(len(addrs))]
|
|
|
+ actions[i] = newTestAction(addr, r)
|
|
|
+ }
|
|
|
+ // Generate snapshot indexes.
|
|
|
+ nsnapshots := int(math.Sqrt(float64(size)))
|
|
|
+ if size > 0 && nsnapshots == 0 {
|
|
|
+ nsnapshots = 1
|
|
|
+ }
|
|
|
+ snapshots := make([]int, nsnapshots)
|
|
|
+ snaplen := len(actions) / nsnapshots
|
|
|
+ for i := range snapshots {
|
|
|
+ // Try to place the snapshots some number of actions apart from each other.
|
|
|
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen)
|
|
|
+ }
|
|
|
+ return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
|
|
|
+}
|
|
|
+
|
|
|
+func (test *snapshotTest) String() string {
|
|
|
+ out := new(bytes.Buffer)
|
|
|
+ sindex := 0
|
|
|
+ for i, action := range test.actions {
|
|
|
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
|
|
|
+ fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
|
|
|
+ sindex++
|
|
|
+ }
|
|
|
+ fmt.Fprintf(out, "%4d: %s\n", i, action.name)
|
|
|
+ }
|
|
|
+ return out.String()
|
|
|
+}
|
|
|
+
|
|
|
+func (test *snapshotTest) run() bool {
|
|
|
+ // Run all actions and create snapshots.
|
|
|
+ var (
|
|
|
+ db, _ = ethdb.NewMemDatabase()
|
|
|
+ state, _ = New(common.Hash{}, db)
|
|
|
+ snapshotRevs = make([]int, len(test.snapshots))
|
|
|
+ sindex = 0
|
|
|
+ )
|
|
|
+ for i, action := range test.actions {
|
|
|
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
|
|
|
+ snapshotRevs[sindex] = state.Snapshot()
|
|
|
+ sindex++
|
|
|
+ }
|
|
|
+ action.fn(action, state)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Revert all snapshots in reverse order. Each revert must yield a state
|
|
|
+ // that is equivalent to fresh state with all actions up the snapshot applied.
|
|
|
+ for sindex--; sindex >= 0; sindex-- {
|
|
|
+ checkstate, _ := New(common.Hash{}, db)
|
|
|
+ for _, action := range test.actions[:test.snapshots[sindex]] {
|
|
|
+ action.fn(action, checkstate)
|
|
|
+ }
|
|
|
+ state.RevertToSnapshot(snapshotRevs[sindex])
|
|
|
+ if err := test.checkEqual(state, checkstate); err != nil {
|
|
|
+ test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+// checkEqual checks that methods of state and checkstate return the same values.
|
|
|
+func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
|
|
|
+ for _, addr := range test.addrs {
|
|
|
+ var err error
|
|
|
+ checkeq := func(op string, a, b interface{}) bool {
|
|
|
+ if err == nil && !reflect.DeepEqual(a, b) {
|
|
|
+ err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ // Check basic accessor methods.
|
|
|
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
|
|
|
+ checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr))
|
|
|
+ checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
|
|
|
+ checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
|
|
|
+ checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
|
|
|
+ checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
|
|
|
+ checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
|
|
|
+ // Check storage.
|
|
|
+ if obj := state.GetStateObject(addr); obj != nil {
|
|
|
+ obj.ForEachStorage(func(key, val common.Hash) bool {
|
|
|
+ return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
|
|
|
+ })
|
|
|
+ checkobj := checkstate.GetStateObject(addr)
|
|
|
+ checkobj.ForEachStorage(func(key, checkval common.Hash) bool {
|
|
|
+ return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
|
|
|
+ })
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
|
|
|
+ return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
|
|
|
+ state.GetRefund(), checkstate.GetRefund())
|
|
|
+ }
|
|
|
+ if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
|
|
|
+ return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
|
|
|
+ state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|