Bladeren bron

trie: make stacktrie support binary marshal/unmarshal (#22685)

Martin Holst Swende 4 jaren geleden
bovenliggende
commit
581539c6ee
2 gewijzigde bestanden met toevoegingen van 140 en 1 verwijderingen
  1. 94 0
      trie/stacktrie.go
  2. 46 1
      trie/stacktrie_test.go

+ 94 - 0
trie/stacktrie.go

@@ -17,8 +17,12 @@
 package trie
 
 import (
+	"bufio"
+	"bytes"
+	"encoding/gob"
 	"errors"
 	"fmt"
+	"io"
 	"sync"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -66,6 +70,96 @@ func NewStackTrie(db ethdb.KeyValueWriter) *StackTrie {
 	}
 }
 
+// NewFromBinary initialises a serialized stacktrie with the given db.
+func NewFromBinary(data []byte, db ethdb.KeyValueWriter) (*StackTrie, error) {
+	var st StackTrie
+	if err := st.UnmarshalBinary(data); err != nil {
+		return nil, err
+	}
+	// If a database is used, we need to recursively add it to every child
+	if db != nil {
+		st.setDb(db)
+	}
+	return &st, nil
+}
+
+// MarshalBinary implements encoding.BinaryMarshaler
+func (st *StackTrie) MarshalBinary() (data []byte, err error) {
+	var (
+		b bytes.Buffer
+		w = bufio.NewWriter(&b)
+	)
+	if err := gob.NewEncoder(w).Encode(struct {
+		Nodetype  uint8
+		Val       []byte
+		Key       []byte
+		KeyOffset uint8
+	}{
+		st.nodeType,
+		st.val,
+		st.key,
+		uint8(st.keyOffset),
+	}); err != nil {
+		return nil, err
+	}
+	for _, child := range st.children {
+		if child == nil {
+			w.WriteByte(0)
+			continue
+		}
+		w.WriteByte(1)
+		if childData, err := child.MarshalBinary(); err != nil {
+			return nil, err
+		} else {
+			w.Write(childData)
+		}
+	}
+	w.Flush()
+	return b.Bytes(), nil
+}
+
+// UnmarshalBinary implements encoding.BinaryUnmarshaler
+func (st *StackTrie) UnmarshalBinary(data []byte) error {
+	r := bytes.NewReader(data)
+	return st.unmarshalBinary(r)
+}
+
+func (st *StackTrie) unmarshalBinary(r io.Reader) error {
+	var dec struct {
+		Nodetype  uint8
+		Val       []byte
+		Key       []byte
+		KeyOffset uint8
+	}
+	gob.NewDecoder(r).Decode(&dec)
+	st.nodeType = dec.Nodetype
+	st.val = dec.Val
+	st.key = dec.Key
+	st.keyOffset = int(dec.KeyOffset)
+
+	var hasChild = make([]byte, 1)
+	for i := range st.children {
+		if _, err := r.Read(hasChild); err != nil {
+			return err
+		} else if hasChild[0] == 0 {
+			continue
+		}
+		var child StackTrie
+		child.unmarshalBinary(r)
+		st.children[i] = &child
+	}
+	return nil
+}
+
+func (st *StackTrie) setDb(db ethdb.KeyValueWriter) {
+	st.db = db
+	for _, child := range st.children {
+		if child != nil {
+			child.setDb(db)
+		}
+	}
+}
+
 func newLeaf(ko int, key, val []byte, db ethdb.KeyValueWriter) *StackTrie {
 	st := stackTrieFromPool(db)
 	st.nodeType = leafNode

+ 46 - 1
trie/stacktrie_test.go

@@ -151,7 +151,6 @@ func TestStacktrieNotModifyValues(t *testing.T) {
 			return big.NewInt(int64(i)).Bytes()
 		}
 	}
-
 	for i := 0; i < 1000; i++ {
 		key := common.BigToHash(keyB)
 		value := getValue(i)
@@ -168,5 +167,51 @@ func TestStacktrieNotModifyValues(t *testing.T) {
 		if !bytes.Equal(have, want) {
 			t.Fatalf("item %d, have %#x want %#x", i, have, want)
 		}
+
+	}
+}
+
+// TestStacktrieSerialization tests that the stacktrie works well if we
+// serialize/unserialize it a lot
+func TestStacktrieSerialization(t *testing.T) {
+	var (
+		st       = NewStackTrie(nil)
+		nt, _    = New(common.Hash{}, NewDatabase(memorydb.New()))
+		keyB     = big.NewInt(1)
+		keyDelta = big.NewInt(1)
+		vals     [][]byte
+		keys     [][]byte
+	)
+	getValue := func(i int) []byte {
+		if i%2 == 0 { // large
+			return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
+		} else { //small
+			return big.NewInt(int64(i)).Bytes()
+		}
+	}
+	for i := 0; i < 10; i++ {
+		vals = append(vals, getValue(i))
+		keys = append(keys, common.BigToHash(keyB).Bytes())
+		keyB = keyB.Add(keyB, keyDelta)
+		keyDelta.Add(keyDelta, common.Big1)
+	}
+	for i, k := range keys {
+		nt.TryUpdate(k, common.CopyBytes(vals[i]))
+	}
+
+	for i, k := range keys {
+		blob, err := st.MarshalBinary()
+		if err != nil {
+			t.Fatal(err)
+		}
+		newSt, err := NewFromBinary(blob, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		st = newSt
+		st.TryUpdate(k, common.CopyBytes(vals[i]))
+	}
+	if have, want := st.Hash(), nt.Hash(); have != want {
+		t.Fatalf("have %#x want %#x", have, want)
 	}
 }