Эх сурвалжийг харах

trie: ensure resolved nodes stay loaded

Commit 40cdcf1183 broke the optimisation which kept nodes resolved
during Get in the trie. The decoder assigned cache generation 0
unconditionally, causing resolved nodes to get flushed on Commit.

This commit fixes it and adds two tests.
Felix Lange 9 жил өмнө
parent
commit
177cab5fe7
6 өөрчлөгдсөн 95 нэмэгдсэн , 43 устгасан
  1. 1 1
      trie/hasher.go
  2. 13 13
      trie/node.go
  3. 1 1
      trie/proof.go
  4. 3 3
      trie/sync.go
  5. 7 4
      trie/trie.go
  6. 70 21
      trie/trie_test.go

+ 1 - 1
trie/hasher.go

@@ -58,7 +58,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
 			return hash, n, nil
 		}
 		if n.canUnload(h.cachegen, h.cachelimit) {
-			// Evict the node from cache. All of its subnodes will have a lower or equal
+			// Unload the node from cache. All of its subnodes will have a lower or equal
 			// cache generation number.
 			return hash, hash, nil
 		}

+ 13 - 13
trie/node.go

@@ -104,8 +104,8 @@ func (n valueNode) fstring(ind string) string {
 	return fmt.Sprintf("%x ", []byte(n))
 }
 
-func mustDecodeNode(hash, buf []byte) node {
-	n, err := decodeNode(hash, buf)
+func mustDecodeNode(hash, buf []byte, cachegen uint16) node {
+	n, err := decodeNode(hash, buf, cachegen)
 	if err != nil {
 		panic(fmt.Sprintf("node %x: %v", hash, err))
 	}
@@ -113,7 +113,7 @@ func mustDecodeNode(hash, buf []byte) node {
 }
 
 // decodeNode parses the RLP encoding of a trie node.
-func decodeNode(hash, buf []byte) (node, error) {
+func decodeNode(hash, buf []byte, cachegen uint16) (node, error) {
 	if len(buf) == 0 {
 		return nil, io.ErrUnexpectedEOF
 	}
@@ -123,22 +123,22 @@ func decodeNode(hash, buf []byte) (node, error) {
 	}
 	switch c, _ := rlp.CountValues(elems); c {
 	case 2:
-		n, err := decodeShort(hash, buf, elems)
+		n, err := decodeShort(hash, buf, elems, cachegen)
 		return n, wrapError(err, "short")
 	case 17:
-		n, err := decodeFull(hash, buf, elems)
+		n, err := decodeFull(hash, buf, elems, cachegen)
 		return n, wrapError(err, "full")
 	default:
 		return nil, fmt.Errorf("invalid number of list elements: %v", c)
 	}
 }
 
-func decodeShort(hash, buf, elems []byte) (node, error) {
+func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
 	kbuf, rest, err := rlp.SplitString(elems)
 	if err != nil {
 		return nil, err
 	}
-	flag := nodeFlag{hash: hash}
+	flag := nodeFlag{hash: hash, gen: cachegen}
 	key := compactDecode(kbuf)
 	if key[len(key)-1] == 16 {
 		// value node
@@ -148,17 +148,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
 		}
 		return &shortNode{key, append(valueNode{}, val...), flag}, nil
 	}
-	r, _, err := decodeRef(rest)
+	r, _, err := decodeRef(rest, cachegen)
 	if err != nil {
 		return nil, wrapError(err, "val")
 	}
 	return &shortNode{key, r, flag}, nil
 }
 
-func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
-	n := &fullNode{flags: nodeFlag{hash: hash}}
+func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) {
+	n := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}}
 	for i := 0; i < 16; i++ {
-		cld, rest, err := decodeRef(elems)
+		cld, rest, err := decodeRef(elems, cachegen)
 		if err != nil {
 			return n, wrapError(err, fmt.Sprintf("[%d]", i))
 		}
@@ -176,7 +176,7 @@ func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
 
 const hashLen = len(common.Hash{})
 
-func decodeRef(buf []byte) (node, []byte, error) {
+func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) {
 	kind, val, rest, err := rlp.Split(buf)
 	if err != nil {
 		return nil, buf, err
@@ -189,7 +189,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
 			err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
 			return nil, buf, err
 		}
-		n, err := decodeNode(nil, buf)
+		n, err := decodeNode(nil, buf, cachegen)
 		return n, rest, err
 	case kind == rlp.String && len(val) == 0:
 		// empty node

+ 1 - 1
trie/proof.go

@@ -101,7 +101,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
 		if !bytes.Equal(sha.Sum(nil), wantHash) {
 			return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
 		}
-		n, err := decodeNode(wantHash, buf)
+		n, err := decodeNode(wantHash, buf, 0)
 		if err != nil {
 			return nil, fmt.Errorf("bad proof node %d: %v", i, err)
 		}

+ 3 - 3
trie/sync.go

@@ -82,7 +82,7 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c
 	}
 	key := root.Bytes()
 	blob, _ := s.database.Get(key)
-	if local, err := decodeNode(key, blob); local != nil && err == nil {
+	if local, err := decodeNode(key, blob, 0); local != nil && err == nil {
 		return
 	}
 	// Assemble the new sub-trie sync request
@@ -158,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) {
 			continue
 		}
 		// Decode the node data content and update the request
-		node, err := decodeNode(item.Hash[:], item.Data)
+		node, err := decodeNode(item.Hash[:], item.Data, 0)
 		if err != nil {
 			return i, err
 		}
@@ -246,7 +246,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
 		if node, ok := (*child.node).(hashNode); ok {
 			// Try to resolve the node from the local database
 			blob, _ := s.database.Get(node)
-			if local, err := decodeNode(node[:], blob); local != nil && err == nil {
+			if local, err := decodeNode(node[:], blob, 0); local != nil && err == nil {
 				*child.node = local
 				continue
 			}

+ 7 - 4
trie/trie.go

@@ -144,14 +144,15 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
 		if err == nil && didResolve {
 			n = n.copy()
 			n.Val = newnode
+			n.flags.gen = t.cachegen
 		}
 		return value, n, didResolve, err
 	case *fullNode:
 		value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
 		if err == nil && didResolve {
 			n = n.copy()
+			n.flags.gen = t.cachegen
 			n.Children[key[pos]] = newnode
-
 		}
 		return value, n, didResolve, err
 	case hashNode:
@@ -247,7 +248,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
 			return false, n, err
 		}
 		n = n.copy()
-		n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
+		n.flags = t.newFlag()
+		n.Children[key[0]] = nn
 		return true, n, nil
 
 	case nil:
@@ -331,7 +333,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
 			return false, n, err
 		}
 		n = n.copy()
-		n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
+		n.flags = t.newFlag()
+		n.Children[key[0]] = nn
 
 		// Check how many non-nil entries are left after deleting and
 		// reduce the full node to a short node if only one entry is
@@ -427,7 +430,7 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) {
 			SuffixLen: len(suffix),
 		}
 	}
-	dec := mustDecodeNode(n, enc)
+	dec := mustDecodeNode(n, enc, t.cachegen)
 	return dec, nil
 }
 

+ 70 - 21
trie/trie_test.go

@@ -300,25 +300,6 @@ func TestReplication(t *testing.T) {
 	}
 }
 
-// Not an actual test
-func TestOutput(t *testing.T) {
-	t.Skip()
-
-	base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
-	trie := newEmpty()
-	for i := 0; i < 50; i++ {
-		updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
-	}
-	fmt.Println("############################## FULL ################################")
-	fmt.Println(trie.root)
-
-	trie.Commit()
-	fmt.Println("############################## SMALL ################################")
-	trie2, _ := New(trie.Hash(), trie.db)
-	getString(trie2, base+"20")
-	fmt.Println(trie2.root)
-}
-
 func TestLargeValue(t *testing.T) {
 	trie := newEmpty()
 	trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
@@ -326,14 +307,56 @@ func TestLargeValue(t *testing.T) {
 	trie.Hash()
 }
 
+type countingDB struct {
+	Database
+	gets map[string]int
+}
+
+func (db *countingDB) Get(key []byte) ([]byte, error) {
+	db.gets[string(key)]++
+	return db.Database.Get(key)
+}
+
+// TestCacheUnload checks that decoded nodes are unloaded after a
+// certain number of commit operations.
+func TestCacheUnload(t *testing.T) {
+	// Create test trie with two branches.
+	trie := newEmpty()
+	key1 := "---------------------------------"
+	key2 := "---some other branch"
+	updateString(trie, key1, "this is the branch of key1.")
+	updateString(trie, key2, "this is the branch of key2.")
+	root, _ := trie.Commit()
+
+	// Commit the trie repeatedly and access key1.
+	// The branch containing it is loaded from DB exactly two times:
+	// in the 0th and 6th iteration.
+	db := &countingDB{Database: trie.db, gets: make(map[string]int)}
+	trie, _ = New(root, db)
+	trie.SetCacheLimit(5)
+	for i := 0; i < 12; i++ {
+		getString(trie, key1)
+		trie.Commit()
+	}
+
+	// Check that it got loaded two times.
+	for dbkey, count := range db.gets {
+		if count != 2 {
+			t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
+		}
+	}
+}
+
+// randTest performs random trie operations.
+// Instances of this test are created by Generate.
+type randTest []randTestStep
+
 type randTestStep struct {
 	op    int
 	key   []byte // for opUpdate, opDelete, opGet
 	value []byte // for opUpdate
 }
 
-type randTest []randTestStep
-
 const (
 	opUpdate = iota
 	opDelete
@@ -342,6 +365,7 @@ const (
 	opHash
 	opReset
 	opItercheckhash
+	opCheckCacheInvariant
 	opMax // boundary value, not an actual op
 )
 
@@ -437,7 +461,32 @@ func runRandTest(rt randTest) bool {
 				fmt.Println("hashes not equal")
 				return false
 			}
+		case opCheckCacheInvariant:
+			return checkCacheInvariant(tr.root, tr.cachegen, 0)
+		}
+	}
+	return true
+}
+
+func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool {
+	switch n := n.(type) {
+	case *shortNode:
+		if n.flags.gen > parentCachegen {
+			fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
+			return false
+		}
+		return checkCacheInvariant(n.Val, n.flags.gen, depth+1)
+	case *fullNode:
+		if n.flags.gen > parentCachegen {
+			fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
+			return false
+		}
+		for _, child := range n.Children {
+			if !checkCacheInvariant(child, n.flags.gen, depth+1) {
+				return false
+			}
 		}
+		return true
 	}
 	return true
 }