Quellcode durchsuchen

core/state/snapshot: fix binary iterator (#20970)

gary rong vor 5 Jahren
Ursprung
Commit
8a2e8faadd
2 geänderte Dateien mit 66 neuen und 12 gelöschten Zeilen
  1. 14 8
      core/state/snapshot/iterator_binary.go
  2. 52 4
      core/state/snapshot/iterator_test.go

+ 14 - 8
core/state/snapshot/iterator_binary.go

@@ -26,7 +26,7 @@ import (
 // a snapshot, which may or may npt be composed of multiple layers. Performance
 // wise this iterator is slow, it's meant for cross validating the fast one,
 type binaryAccountIterator struct {
-	a     *diffAccountIterator
+	a     AccountIterator
 	b     AccountIterator
 	aDone bool
 	bDone bool
@@ -40,10 +40,16 @@ func (dl *diffLayer) newBinaryAccountIterator() AccountIterator {
 	parent, ok := dl.parent.(*diffLayer)
 	if !ok {
 		// parent is the disk layer
-		return dl.AccountIterator(common.Hash{})
+		l := &binaryAccountIterator{
+			a: dl.AccountIterator(common.Hash{}),
+			b: dl.Parent().AccountIterator(common.Hash{}),
+		}
+		l.aDone = !l.a.Next()
+		l.bDone = !l.b.Next()
+		return l
 	}
 	l := &binaryAccountIterator{
-		a: dl.AccountIterator(common.Hash{}).(*diffAccountIterator),
+		a: dl.AccountIterator(common.Hash{}),
 		b: parent.newBinaryAccountIterator(),
 	}
 	l.aDone = !l.a.Next()
@@ -58,19 +64,18 @@ func (it *binaryAccountIterator) Next() bool {
 	if it.aDone && it.bDone {
 		return false
 	}
-	nextB := it.b.Hash()
 first:
-	nextA := it.a.Hash()
 	if it.aDone {
+		it.k = it.b.Hash()
 		it.bDone = !it.b.Next()
-		it.k = nextB
 		return true
 	}
 	if it.bDone {
+		it.k = it.a.Hash()
 		it.aDone = !it.a.Next()
-		it.k = nextA
 		return true
 	}
+	nextA, nextB := it.a.Hash(), it.b.Hash()
 	if diff := bytes.Compare(nextA[:], nextB[:]); diff < 0 {
 		it.aDone = !it.a.Next()
 		it.k = nextA
@@ -100,7 +105,8 @@ func (it *binaryAccountIterator) Hash() common.Hash {
 // nil if the iterated snapshot stack became stale (you can check Error after
 // to see if it failed or not).
 func (it *binaryAccountIterator) Account() []byte {
-	blob, err := it.a.layer.AccountRLP(it.k)
+	// The topmost iterator must be `diffAccountIterator`
+	blob, err := it.a.(*diffAccountIterator).layer.AccountRLP(it.k)
 	if err != nil {
 		it.fail = err
 		return nil

+ 52 - 4
core/state/snapshot/iterator_test.go

@@ -177,9 +177,22 @@ func TestAccountIteratorTraversal(t *testing.T) {
 	verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator())
 
 	it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
-	defer it.Release()
+	verifyIterator(t, 7, it)
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(common.HexToHash("0x04"), 2)
+	verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator())
 
+	it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
 	verifyIterator(t, 7, it)
+	it.Release()
 }
 
 // TestAccountIteratorTraversalValues tests some multi-layer iteration, where we
@@ -242,8 +255,6 @@ func TestAccountIteratorTraversalValues(t *testing.T) {
 	snaps.Update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, h, nil)
 
 	it, _ := snaps.AccountIterator(common.HexToHash("0x09"), common.Hash{})
-	defer it.Release()
-
 	head := snaps.Snapshot(common.HexToHash("0x09"))
 	for it.Next() {
 		hash := it.Hash()
@@ -255,6 +266,29 @@ func TestAccountIteratorTraversalValues(t *testing.T) {
 			t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want)
 		}
 	}
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(common.HexToHash("0x09"), 2)
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x09"), common.Hash{})
+	for it.Next() {
+		hash := it.Hash()
+		want, err := head.AccountRLP(hash)
+		if err != nil {
+			t.Fatalf("failed to retrieve expected account: %v", err)
+		}
+		if have := it.Account(); !bytes.Equal(want, have) {
+			t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want)
+		}
+	}
+	it.Release()
 }
 
 // This testcase is notorious, all layers contain the exact same 200 accounts.
@@ -289,9 +323,23 @@ func TestAccountIteratorLargeTraversal(t *testing.T) {
 	verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator())
 
 	it, _ := snaps.AccountIterator(common.HexToHash("0x80"), common.Hash{})
-	defer it.Release()
+	verifyIterator(t, 200, it)
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(common.HexToHash("0x80"), 2)
 
+	verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator())
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x80"), common.Hash{})
 	verifyIterator(t, 200, it)
+	it.Release()
 }
 
 // TestAccountIteratorFlattening tests what happens when we