Эх сурвалжийг харах

core, eth/downloader: ensure state presence in ancestor lookup

Péter Szilágyi 10 жил өмнө
parent
commit
9e011ff1cd

+ 13 - 0
core/blockchain.go

@@ -587,6 +587,19 @@ func (bc *BlockChain) HasBlock(hash common.Hash) bool {
 	return bc.GetBlock(hash) != nil
 	return bc.GetBlock(hash) != nil
 }
 }
 
 
+// HasBlockAndState checks if a block and associated state trie is fully present
+// in the database or not, caching it if present.
+func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool {
+	// Check first that the block itself is known
+	block := bc.GetBlock(hash)
+	if block == nil {
+		return false
+	}
+	// Ensure the associated state is also present
+	_, err := state.New(block.Root(), bc.chainDb)
+	return err == nil
+}
+
 // GetBlock retrieves a block from the database by hash, caching it if found.
 // GetBlock retrieves a block from the database by hash, caching it if found.
 func (self *BlockChain) GetBlock(hash common.Hash) *types.Block {
 func (self *BlockChain) GetBlock(hash common.Hash) *types.Block {
 	// Short circuit if the block's already in the cache, retrieve otherwise
 	// Short circuit if the block's already in the cache, retrieve otherwise

+ 0 - 1
core/state/statedb.go

@@ -57,7 +57,6 @@ type StateDB struct {
 func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
 func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
 	tr, err := trie.NewSecure(root, db)
 	tr, err := trie.NewSecure(root, db)
 	if err != nil {
 	if err != nil {
-		glog.Errorf("can't create state trie with root %x: %v", root[:], err)
 		return nil, err
 		return nil, err
 	}
 	}
 	return &StateDB{
 	return &StateDB{

+ 52 - 52
eth/downloader/downloader.go

@@ -112,20 +112,20 @@ type Downloader struct {
 	syncStatsLock        sync.RWMutex // Lock protecting the sync stats fields
 	syncStatsLock        sync.RWMutex // Lock protecting the sync stats fields
 
 
 	// Callbacks
 	// Callbacks
-	hasHeader       headerCheckFn            // Checks if a header is present in the chain
-	hasBlock        blockCheckFn             // Checks if a block is present in the chain
-	getHeader       headerRetrievalFn        // Retrieves a header from the chain
-	getBlock        blockRetrievalFn         // Retrieves a block from the chain
-	headHeader      headHeaderRetrievalFn    // Retrieves the head header from the chain
-	headBlock       headBlockRetrievalFn     // Retrieves the head block from the chain
-	headFastBlock   headFastBlockRetrievalFn // Retrieves the head fast-sync block from the chain
-	commitHeadBlock headBlockCommitterFn     // Commits a manually assembled block as the chain head
-	getTd           tdRetrievalFn            // Retrieves the TD of a block from the chain
-	insertHeaders   headerChainInsertFn      // Injects a batch of headers into the chain
-	insertBlocks    blockChainInsertFn       // Injects a batch of blocks into the chain
-	insertReceipts  receiptChainInsertFn     // Injects a batch of blocks and their receipts into the chain
-	rollback        chainRollbackFn          // Removes a batch of recently added chain links
-	dropPeer        peerDropFn               // Drops a peer for misbehaving
+	hasHeader        headerCheckFn            // Checks if a header is present in the chain
+	hasBlockAndState blockAndStateCheckFn     // Checks if a block and associated state is present in the chain
+	getHeader        headerRetrievalFn        // Retrieves a header from the chain
+	getBlock         blockRetrievalFn         // Retrieves a block from the chain
+	headHeader       headHeaderRetrievalFn    // Retrieves the head header from the chain
+	headBlock        headBlockRetrievalFn     // Retrieves the head block from the chain
+	headFastBlock    headFastBlockRetrievalFn // Retrieves the head fast-sync block from the chain
+	commitHeadBlock  headBlockCommitterFn     // Commits a manually assembled block as the chain head
+	getTd            tdRetrievalFn            // Retrieves the TD of a block from the chain
+	insertHeaders    headerChainInsertFn      // Injects a batch of headers into the chain
+	insertBlocks     blockChainInsertFn       // Injects a batch of blocks into the chain
+	insertReceipts   receiptChainInsertFn     // Injects a batch of blocks and their receipts into the chain
+	rollback         chainRollbackFn          // Removes a batch of recently added chain links
+	dropPeer         peerDropFn               // Drops a peer for misbehaving
 
 
 	// Status
 	// Status
 	synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
 	synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
@@ -156,41 +156,41 @@ type Downloader struct {
 }
 }
 
 
 // New creates a new downloader to fetch hashes and blocks from remote peers.
 // New creates a new downloader to fetch hashes and blocks from remote peers.
-func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn,
-	getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn,
-	commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn,
-	insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader {
+func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlockAndState blockAndStateCheckFn,
+	getHeader headerRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn,
+	headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn,
+	insertBlocks blockChainInsertFn, insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader {
 
 
 	return &Downloader{
 	return &Downloader{
-		mode:            FullSync,
-		mux:             mux,
-		queue:           newQueue(stateDb),
-		peers:           newPeerSet(),
-		hasHeader:       hasHeader,
-		hasBlock:        hasBlock,
-		getHeader:       getHeader,
-		getBlock:        getBlock,
-		headHeader:      headHeader,
-		headBlock:       headBlock,
-		headFastBlock:   headFastBlock,
-		commitHeadBlock: commitHeadBlock,
-		getTd:           getTd,
-		insertHeaders:   insertHeaders,
-		insertBlocks:    insertBlocks,
-		insertReceipts:  insertReceipts,
-		rollback:        rollback,
-		dropPeer:        dropPeer,
-		newPeerCh:       make(chan *peer, 1),
-		hashCh:          make(chan dataPack, 1),
-		blockCh:         make(chan dataPack, 1),
-		headerCh:        make(chan dataPack, 1),
-		bodyCh:          make(chan dataPack, 1),
-		receiptCh:       make(chan dataPack, 1),
-		stateCh:         make(chan dataPack, 1),
-		blockWakeCh:     make(chan bool, 1),
-		bodyWakeCh:      make(chan bool, 1),
-		receiptWakeCh:   make(chan bool, 1),
-		stateWakeCh:     make(chan bool, 1),
+		mode:             FullSync,
+		mux:              mux,
+		queue:            newQueue(stateDb),
+		peers:            newPeerSet(),
+		hasHeader:        hasHeader,
+		hasBlockAndState: hasBlockAndState,
+		getHeader:        getHeader,
+		getBlock:         getBlock,
+		headHeader:       headHeader,
+		headBlock:        headBlock,
+		headFastBlock:    headFastBlock,
+		commitHeadBlock:  commitHeadBlock,
+		getTd:            getTd,
+		insertHeaders:    insertHeaders,
+		insertBlocks:     insertBlocks,
+		insertReceipts:   insertReceipts,
+		rollback:         rollback,
+		dropPeer:         dropPeer,
+		newPeerCh:        make(chan *peer, 1),
+		hashCh:           make(chan dataPack, 1),
+		blockCh:          make(chan dataPack, 1),
+		headerCh:         make(chan dataPack, 1),
+		bodyCh:           make(chan dataPack, 1),
+		receiptCh:        make(chan dataPack, 1),
+		stateCh:          make(chan dataPack, 1),
+		blockWakeCh:      make(chan bool, 1),
+		bodyWakeCh:       make(chan bool, 1),
+		receiptWakeCh:    make(chan bool, 1),
+		stateWakeCh:      make(chan bool, 1),
 	}
 	}
 }
 }
 
 
@@ -564,7 +564,7 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) {
 			// Check if a common ancestor was found
 			// Check if a common ancestor was found
 			finished = true
 			finished = true
 			for i := len(hashes) - 1; i >= 0; i-- {
 			for i := len(hashes) - 1; i >= 0; i-- {
-				if d.hasBlock(hashes[i]) {
+				if d.hasBlockAndState(hashes[i]) {
 					number, hash = uint64(from)+uint64(i), hashes[i]
 					number, hash = uint64(from)+uint64(i), hashes[i]
 					break
 					break
 				}
 				}
@@ -620,11 +620,11 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) {
 				arrived = true
 				arrived = true
 
 
 				// Modify the search interval based on the response
 				// Modify the search interval based on the response
-				block := d.getBlock(hashes[0])
-				if block == nil {
+				if !d.hasBlockAndState(hashes[0]) {
 					end = check
 					end = check
 					break
 					break
 				}
 				}
+				block := d.getBlock(hashes[0]) // this doesn't check state, hence the above explicit check
 				if block.NumberU64() != check {
 				if block.NumberU64() != check {
 					glog.V(logger.Debug).Infof("%v: non requested hash #%d [%x…], instead of #%d", p, block.NumberU64(), block.Hash().Bytes()[:4], check)
 					glog.V(logger.Debug).Infof("%v: non requested hash #%d [%x…], instead of #%d", p, block.NumberU64(), block.Hash().Bytes()[:4], check)
 					return 0, errBadPeer
 					return 0, errBadPeer
@@ -989,7 +989,7 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
 			// Check if a common ancestor was found
 			// Check if a common ancestor was found
 			finished = true
 			finished = true
 			for i := len(headers) - 1; i >= 0; i-- {
 			for i := len(headers) - 1; i >= 0; i-- {
-				if (d.mode != LightSync && d.hasBlock(headers[i].Hash())) || (d.mode == LightSync && d.hasHeader(headers[i].Hash())) {
+				if (d.mode != LightSync && d.hasBlockAndState(headers[i].Hash())) || (d.mode == LightSync && d.hasHeader(headers[i].Hash())) {
 					number, hash = headers[i].Number.Uint64(), headers[i].Hash()
 					number, hash = headers[i].Number.Uint64(), headers[i].Hash()
 					break
 					break
 				}
 				}
@@ -1045,7 +1045,7 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
 				arrived = true
 				arrived = true
 
 
 				// Modify the search interval based on the response
 				// Modify the search interval based on the response
-				if (d.mode == FullSync && !d.hasBlock(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) {
+				if (d.mode == FullSync && !d.hasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) {
 					end = check
 					end = check
 					break
 					break
 				}
 				}

+ 14 - 3
eth/downloader/downloader_test.go

@@ -153,6 +153,8 @@ func newTester() *downloadTester {
 		peerChainTds: make(map[string]map[common.Hash]*big.Int),
 		peerChainTds: make(map[string]map[common.Hash]*big.Int),
 	}
 	}
 	tester.stateDb, _ = ethdb.NewMemDatabase()
 	tester.stateDb, _ = ethdb.NewMemDatabase()
+	tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00})
+
 	tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
 	tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
 		tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
 		tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
 		tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
 		tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
@@ -180,9 +182,14 @@ func (dl *downloadTester) hasHeader(hash common.Hash) bool {
 	return dl.getHeader(hash) != nil
 	return dl.getHeader(hash) != nil
 }
 }
 
 
-// hasBlock checks if a block is present in the testers canonical chain.
+// hasBlock checks if a block and associated state is present in the testers canonical chain.
 func (dl *downloadTester) hasBlock(hash common.Hash) bool {
 func (dl *downloadTester) hasBlock(hash common.Hash) bool {
-	return dl.getBlock(hash) != nil
+	block := dl.getBlock(hash)
+	if block == nil {
+		return false
+	}
+	_, err := dl.stateDb.Get(block.Root().Bytes())
+	return err == nil
 }
 }
 
 
 // getHeader retrieves a header from the testers canonical chain.
 // getHeader retrieves a header from the testers canonical chain.
@@ -295,8 +302,10 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
 	defer dl.lock.Unlock()
 	defer dl.lock.Unlock()
 
 
 	for i, block := range blocks {
 	for i, block := range blocks {
-		if _, ok := dl.ownBlocks[block.ParentHash()]; !ok {
+		if parent, ok := dl.ownBlocks[block.ParentHash()]; !ok {
 			return i, errors.New("unknown parent")
 			return i, errors.New("unknown parent")
+		} else if _, err := dl.stateDb.Get(parent.Root().Bytes()); err != nil {
+			return i, fmt.Errorf("unknown parent state %x: %v", parent.Root(), err)
 		}
 		}
 		if _, ok := dl.ownHeaders[block.Hash()]; !ok {
 		if _, ok := dl.ownHeaders[block.Hash()]; !ok {
 			dl.ownHashes = append(dl.ownHashes, block.Hash())
 			dl.ownHashes = append(dl.ownHashes, block.Hash())
@@ -1103,6 +1112,8 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
 }
 }
 
 
 // Tests that upon detecting an invalid header, the recent ones are rolled back
 // Tests that upon detecting an invalid header, the recent ones are rolled back
+// for various failure scenarios. Afterwards a full sync is attempted to make
+// sure no state was corrupted.
 func TestInvalidHeaderRollback63Fast(t *testing.T)  { testInvalidHeaderRollback(t, 63, FastSync) }
 func TestInvalidHeaderRollback63Fast(t *testing.T)  { testInvalidHeaderRollback(t, 63, FastSync) }
 func TestInvalidHeaderRollback64Fast(t *testing.T)  { testInvalidHeaderRollback(t, 64, FastSync) }
 func TestInvalidHeaderRollback64Fast(t *testing.T)  { testInvalidHeaderRollback(t, 64, FastSync) }
 func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) }
 func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) }

+ 2 - 2
eth/downloader/types.go

@@ -27,8 +27,8 @@ import (
 // headerCheckFn is a callback type for verifying a header's presence in the local chain.
 // headerCheckFn is a callback type for verifying a header's presence in the local chain.
 type headerCheckFn func(common.Hash) bool
 type headerCheckFn func(common.Hash) bool
 
 
-// blockCheckFn is a callback type for verifying a block's presence in the local chain.
-type blockCheckFn func(common.Hash) bool
+// blockAndStateCheckFn is a callback type for verifying block and associated states' presence in the local chain.
+type blockAndStateCheckFn func(common.Hash) bool
 
 
 // headerRetrievalFn is a callback type for retrieving a header from the local chain.
 // headerRetrievalFn is a callback type for retrieving a header from the local chain.
 type headerRetrievalFn func(common.Hash) *types.Header
 type headerRetrievalFn func(common.Hash) *types.Header

+ 4 - 3
eth/handler.go

@@ -138,9 +138,10 @@ func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool
 		return nil, errIncompatibleConfig
 		return nil, errIncompatibleConfig
 	}
 	}
 	// Construct the different synchronisation mechanisms
 	// Construct the different synchronisation mechanisms
-	manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader, blockchain.GetBlock,
-		blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, blockchain.GetTd,
-		blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer)
+	manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlockAndState, blockchain.GetHeader,
+		blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead,
+		blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback,
+		manager.removePeer)
 
 
 	validator := func(block *types.Block, parent *types.Block) error {
 	validator := func(block *types.Block, parent *types.Block) error {
 		return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)
 		return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)