Kaynağa Gözat

trie: fix for range proof (#21107)

* trie: fix for range proof

* trie: fix typo
gary rong 5 yıl önce
ebeveyn
işleme
070a5e1252
2 değiştirilmiş dosya ile 73 ekleme ve 56 silme
  1. 44 29
      trie/proof.go
  2. 29 27
      trie/proof_test.go

+ 44 - 29
trie/proof.go

@@ -219,54 +219,69 @@ func unsetInternal(n node, left []byte, right []byte) error {
 	if len(left) != len(right) {
 		return errors.New("inconsistent edge path")
 	}
-	// Step down to the fork point
-	prefix, pos := prefixLen(left, right), 0
-	var parent node
+	// Step down to the fork point. There are two scenarios can happen:
+	// - the fork point is a shortnode: the left proof MUST point to a
+	//   non-existent key and the key doesn't match with the shortnode
+	// - the fork point is a fullnode: the left proof can point to an
+	//   existent key or not.
+	var (
+		pos    = 0
+		parent node
+	)
+findFork:
 	for {
-		if pos >= prefix {
-			break
-		}
 		switch rn := (n).(type) {
 		case *shortNode:
+			// The right proof must point to an existent key.
 			if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) {
 				return errors.New("invalid edge path")
 			}
+			rn.flags = nodeFlag{dirty: true}
 			// Special case, the non-existent proof points to the same path
 			// as the existent proof, but the path of existent proof is longer.
-			// In this case, truncate the extra path(it should be recovered
-			// by node insertion).
+			// In this case, the fork point is this shortnode.
 			if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) {
-				fn := parent.(*fullNode)
-				fn.Children[left[pos-1]] = nil
-				return nil
+				break findFork
 			}
-			rn.flags = nodeFlag{dirty: true}
 			parent = n
 			n, pos = rn.Val, pos+len(rn.Key)
 		case *fullNode:
+			leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
+			// The right proof must point to an existent key.
+			if rightnode == nil {
+				return errors.New("invalid edge path")
+			}
 			rn.flags = nodeFlag{dirty: true}
+			if leftnode != rightnode {
+				break findFork
+			}
 			parent = n
-			n, pos = rn.Children[right[pos]], pos+1
+			n, pos = rn.Children[left[pos]], pos+1
 		default:
 			panic(fmt.Sprintf("%T: invalid node: %v", n, n))
 		}
 	}
-	fn, ok := n.(*fullNode)
-	if !ok {
-		return errors.New("the fork point must be a fullnode")
-	}
-	// Find the fork point! Unset all intermediate references
-	for i := left[prefix] + 1; i < right[prefix]; i++ {
-		fn.Children[i] = nil
-	}
-	fn.flags = nodeFlag{dirty: true}
-	if err := unset(fn, fn.Children[left[prefix]], left[prefix:], 1, false); err != nil {
-		return err
-	}
-	if err := unset(fn, fn.Children[right[prefix]], right[prefix:], 1, true); err != nil {
-		return err
+	switch rn := n.(type) {
+	case *shortNode:
+		if _, ok := rn.Val.(valueNode); ok {
+			parent.(*fullNode).Children[right[pos-1]] = nil
+			return nil
+		}
+		return unset(rn, rn.Val, right[pos:], len(rn.Key), true)
+	case *fullNode:
+		for i := left[pos] + 1; i < right[pos]; i++ {
+			rn.Children[i] = nil
+		}
+		if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
+			return err
+		}
+		if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
+			return err
+		}
+		return nil
+	default:
+		panic(fmt.Sprintf("%T: invalid node: %v", n, n))
 	}
-	return nil
 }
 
 // unset removes all internal node references either the left most or right most.
@@ -314,8 +329,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error
 				// The key of fork shortnode is less than the
 				// path(it doesn't belong to the range), keep
 				// it with the cached hash available.
-				return nil
 			}
+			return nil
 		}
 		if _, ok := cld.Val.(valueNode); ok {
 			fn := parent.(*fullNode)

+ 29 - 27
trie/proof_test.go

@@ -397,33 +397,35 @@ func TestAllElementsProof(t *testing.T) {
 
 // TestSingleSideRangeProof tests the range starts from zero.
 func TestSingleSideRangeProof(t *testing.T) {
-	trie := new(Trie)
-	var entries entrySlice
-	for i := 0; i < 4096; i++ {
-		value := &kv{randBytes(32), randBytes(20), false}
-		trie.Update(value.k, value.v)
-		entries = append(entries, value)
-	}
-	sort.Sort(entries)
-
-	var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
-	for _, pos := range cases {
-		firstProof, lastProof := memorydb.New(), memorydb.New()
-		if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil {
-			t.Fatalf("Failed to prove the first node %v", err)
-		}
-		if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil {
-			t.Fatalf("Failed to prove the first node %v", err)
-		}
-		k := make([][]byte, 0)
-		v := make([][]byte, 0)
-		for i := 0; i <= pos; i++ {
-			k = append(k, entries[i].k)
-			v = append(v, entries[i].v)
-		}
-		err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof)
-		if err != nil {
-			t.Fatalf("Expected no error, got %v", err)
+	for i := 0; i < 64; i++ {
+		trie := new(Trie)
+		var entries entrySlice
+		for i := 0; i < 4096; i++ {
+			value := &kv{randBytes(32), randBytes(20), false}
+			trie.Update(value.k, value.v)
+			entries = append(entries, value)
+		}
+		sort.Sort(entries)
+
+		var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
+		for _, pos := range cases {
+			firstProof, lastProof := memorydb.New(), memorydb.New()
+			if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil {
+				t.Fatalf("Failed to prove the first node %v", err)
+			}
+			if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil {
+				t.Fatalf("Failed to prove the first node %v", err)
+			}
+			k := make([][]byte, 0)
+			v := make([][]byte, 0)
+			for i := 0; i <= pos; i++ {
+				k = append(k, entries[i].k)
+				v = append(v, entries[i].v)
+			}
+			err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof)
+			if err != nil {
+				t.Fatalf("Expected no error, got %v", err)
+			}
 		}
 	}
 }