Эх сурвалжийг харах

core, light: get rid of the dual mutexes, hard to reason with

Péter Szilágyi 6 жил өмнө
parent
commit
f25f776c9f

+ 27 - 47
core/blockchain.go

@@ -108,7 +108,6 @@ type BlockChain struct {
 	scope         event.SubscriptionScope
 	scope         event.SubscriptionScope
 	genesisBlock  *types.Block
 	genesisBlock  *types.Block
 
 
-	mu      sync.RWMutex // global mutex for locking chain operations
 	chainmu sync.RWMutex // blockchain insertion lock
 	chainmu sync.RWMutex // blockchain insertion lock
 	procmu  sync.RWMutex // block processor lock
 	procmu  sync.RWMutex // block processor lock
 
 
@@ -281,8 +280,8 @@ func (bc *BlockChain) loadLastState() error {
 func (bc *BlockChain) SetHead(head uint64) error {
 func (bc *BlockChain) SetHead(head uint64) error {
 	log.Warn("Rewinding blockchain", "target", head)
 	log.Warn("Rewinding blockchain", "target", head)
 
 
-	bc.mu.Lock()
-	defer bc.mu.Unlock()
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
 
 
 	// Rewind the header chain, deleting all block bodies until then
 	// Rewind the header chain, deleting all block bodies until then
 	delFn := func(db rawdb.DatabaseDeleter, hash common.Hash, num uint64) {
 	delFn := func(db rawdb.DatabaseDeleter, hash common.Hash, num uint64) {
@@ -340,9 +339,9 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
 		return err
 		return err
 	}
 	}
 	// If all checks out, manually set the head block
 	// If all checks out, manually set the head block
-	bc.mu.Lock()
+	bc.chainmu.Lock()
 	bc.currentBlock.Store(block)
 	bc.currentBlock.Store(block)
-	bc.mu.Unlock()
+	bc.chainmu.Unlock()
 
 
 	log.Info("Committed new head block", "number", block.Number(), "hash", hash)
 	log.Info("Committed new head block", "number", block.Number(), "hash", hash)
 	return nil
 	return nil
@@ -420,8 +419,8 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
 	if err := bc.SetHead(0); err != nil {
 	if err := bc.SetHead(0); err != nil {
 		return err
 		return err
 	}
 	}
-	bc.mu.Lock()
-	defer bc.mu.Unlock()
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
 
 
 	// Prepare the genesis block and reinitialise the chain
 	// Prepare the genesis block and reinitialise the chain
 	if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
 	if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
@@ -468,8 +467,8 @@ func (bc *BlockChain) Export(w io.Writer) error {
 
 
 // ExportN writes a subset of the active chain to the given writer.
 // ExportN writes a subset of the active chain to the given writer.
 func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
 func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
-	bc.mu.RLock()
-	defer bc.mu.RUnlock()
+	bc.chainmu.RLock()
+	defer bc.chainmu.RUnlock()
 
 
 	if first > last {
 	if first > last {
 		return fmt.Errorf("export failed: first (%d) is greater than last (%d)", first, last)
 		return fmt.Errorf("export failed: first (%d) is greater than last (%d)", first, last)
@@ -490,7 +489,6 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
 			reported = time.Now()
 			reported = time.Now()
 		}
 		}
 	}
 	}
-
 	return nil
 	return nil
 }
 }
 
 
@@ -756,8 +754,8 @@ const (
 // Rollback is designed to remove a chain of links from the database that aren't
 // Rollback is designed to remove a chain of links from the database that aren't
 // certain enough to be valid.
 // certain enough to be valid.
 func (bc *BlockChain) Rollback(chain []common.Hash) {
 func (bc *BlockChain) Rollback(chain []common.Hash) {
-	bc.mu.Lock()
-	defer bc.mu.Unlock()
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
 
 
 	for i := len(chain) - 1; i >= 0; i-- {
 	for i := len(chain) - 1; i >= 0; i-- {
 		hash := chain[i]
 		hash := chain[i]
@@ -881,7 +879,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
 	}
 	}
 
 
 	// Update the head fast sync block if better
 	// Update the head fast sync block if better
-	bc.mu.Lock()
+	bc.chainmu.Lock()
 	head := blockChain[len(blockChain)-1]
 	head := blockChain[len(blockChain)-1]
 	if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case
 	if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case
 		currentFastBlock := bc.CurrentFastBlock()
 		currentFastBlock := bc.CurrentFastBlock()
@@ -890,7 +888,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
 			bc.currentFastBlock.Store(head)
 			bc.currentFastBlock.Store(head)
 		}
 		}
 	}
 	}
-	bc.mu.Unlock()
+	bc.chainmu.Unlock()
 
 
 	context := []interface{}{
 	context := []interface{}{
 		"count", stats.processed, "elapsed", common.PrettyDuration(time.Since(start)),
 		"count", stats.processed, "elapsed", common.PrettyDuration(time.Since(start)),
@@ -924,6 +922,15 @@ func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (e
 
 
 // WriteBlockWithState writes the block and all associated state to the database.
 // WriteBlockWithState writes the block and all associated state to the database.
 func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
 func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
+
+	return bc.writeBlockWithState(block, receipts, state)
+}
+
+// writeBlockWithState writes the block and all associated state to the database,
+// but is expects the chain mutex to be held.
+func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
 	bc.wg.Add(1)
 	bc.wg.Add(1)
 	defer bc.wg.Done()
 	defer bc.wg.Done()
 
 
@@ -933,9 +940,6 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
 		return NonStatTy, consensus.ErrUnknownAncestor
 		return NonStatTy, consensus.ErrUnknownAncestor
 	}
 	}
 	// Make sure no inconsistent state is leaked during insertion
 	// Make sure no inconsistent state is leaked during insertion
-	bc.mu.Lock()
-	defer bc.mu.Unlock()
-
 	currentBlock := bc.CurrentBlock()
 	currentBlock := bc.CurrentBlock()
 	localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64())
 	localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64())
 	externTd := new(big.Int).Add(block.Difficulty(), ptd)
 	externTd := new(big.Int).Add(block.Difficulty(), ptd)
@@ -1212,7 +1216,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, []
 		proctime := time.Since(start)
 		proctime := time.Since(start)
 
 
 		// Write the block to the chain and get the status.
 		// Write the block to the chain and get the status.
-		status, err := bc.WriteBlockWithState(block, receipts, state)
+		status, err := bc.writeBlockWithState(block, receipts, state)
 		t3 := time.Now()
 		t3 := time.Now()
 		if err != nil {
 		if err != nil {
 			return it.index, events, coalescedLogs, err
 			return it.index, events, coalescedLogs, err
@@ -1281,7 +1285,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, []
 func (bc *BlockChain) insertSidechain(it *insertIterator) (int, []interface{}, []*types.Log, error) {
 func (bc *BlockChain) insertSidechain(it *insertIterator) (int, []interface{}, []*types.Log, error) {
 	var (
 	var (
 		externTd *big.Int
 		externTd *big.Int
-		current  = bc.CurrentBlock().NumberU64()
+		current  = bc.CurrentBlock()
 	)
 	)
 	// The first sidechain block error is already verified to be ErrPrunedAncestor.
 	// The first sidechain block error is already verified to be ErrPrunedAncestor.
 	// Since we don't import them here, we expect ErrUnknownAncestor for the remaining
 	// Since we don't import them here, we expect ErrUnknownAncestor for the remaining
@@ -1290,7 +1294,7 @@ func (bc *BlockChain) insertSidechain(it *insertIterator) (int, []interface{}, [
 	block, err := it.current(), consensus.ErrPrunedAncestor
 	block, err := it.current(), consensus.ErrPrunedAncestor
 	for ; block != nil && (err == consensus.ErrPrunedAncestor); block, err = it.next() {
 	for ; block != nil && (err == consensus.ErrPrunedAncestor); block, err = it.next() {
 		// Check the canonical state root for that number
 		// Check the canonical state root for that number
-		if number := block.NumberU64(); current >= number {
+		if number := block.NumberU64(); current.NumberU64() >= number {
 			if canonical := bc.GetBlockByNumber(number); canonical != nil && canonical.Root() == block.Root() {
 			if canonical := bc.GetBlockByNumber(number); canonical != nil && canonical.Root() == block.Root() {
 				// This is most likely a shadow-state attack. When a fork is imported into the
 				// This is most likely a shadow-state attack. When a fork is imported into the
 				// database, and it eventually reaches a block height which is not pruned, we
 				// database, and it eventually reaches a block height which is not pruned, we
@@ -1329,7 +1333,7 @@ func (bc *BlockChain) insertSidechain(it *insertIterator) (int, []interface{}, [
 	//
 	//
 	// If the externTd was larger than our local TD, we now need to reimport the previous
 	// If the externTd was larger than our local TD, we now need to reimport the previous
 	// blocks to regenerate the required state
 	// blocks to regenerate the required state
-	localTd := bc.GetTd(bc.CurrentBlock().Hash(), current)
+	localTd := bc.GetTd(current.Hash(), current.NumberU64())
 	if localTd.Cmp(externTd) > 0 {
 	if localTd.Cmp(externTd) > 0 {
 		log.Info("Sidechain written to disk", "start", it.first().NumberU64(), "end", it.previous().NumberU64(), "sidetd", externTd, "localtd", localTd)
 		log.Info("Sidechain written to disk", "start", it.first().NumberU64(), "end", it.previous().NumberU64(), "sidetd", externTd, "localtd", localTd)
 		return it.index, nil, nil, err
 		return it.index, nil, nil, err
@@ -1597,36 +1601,12 @@ func (bc *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) (i
 	defer bc.wg.Done()
 	defer bc.wg.Done()
 
 
 	whFunc := func(header *types.Header) error {
 	whFunc := func(header *types.Header) error {
-		bc.mu.Lock()
-		defer bc.mu.Unlock()
-
 		_, err := bc.hc.WriteHeader(header)
 		_, err := bc.hc.WriteHeader(header)
 		return err
 		return err
 	}
 	}
-
 	return bc.hc.InsertHeaderChain(chain, whFunc, start)
 	return bc.hc.InsertHeaderChain(chain, whFunc, start)
 }
 }
 
 
-// writeHeader writes a header into the local chain, given that its parent is
-// already known. If the total difficulty of the newly inserted header becomes
-// greater than the current known TD, the canonical chain is re-routed.
-//
-// Note: This method is not concurrent-safe with inserting blocks simultaneously
-// into the chain, as side effects caused by reorganisations cannot be emulated
-// without the real blocks. Hence, writing headers directly should only be done
-// in two scenarios: pure-header mode of operation (light clients), or properly
-// separated header/block phases (non-archive clients).
-func (bc *BlockChain) writeHeader(header *types.Header) error {
-	bc.wg.Add(1)
-	defer bc.wg.Done()
-
-	bc.mu.Lock()
-	defer bc.mu.Unlock()
-
-	_, err := bc.hc.WriteHeader(header)
-	return err
-}
-
 // CurrentHeader retrieves the current head header of the canonical chain. The
 // CurrentHeader retrieves the current head header of the canonical chain. The
 // header is retrieved from the HeaderChain's internal cache.
 // header is retrieved from the HeaderChain's internal cache.
 func (bc *BlockChain) CurrentHeader() *types.Header {
 func (bc *BlockChain) CurrentHeader() *types.Header {
@@ -1675,8 +1655,8 @@ func (bc *BlockChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []com
 //
 //
 // Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
 // Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
 func (bc *BlockChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
 func (bc *BlockChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
-	bc.chainmu.Lock()
-	defer bc.chainmu.Unlock()
+	bc.chainmu.RLock()
+	defer bc.chainmu.RUnlock()
 
 
 	return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
 	return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
 }
 }

+ 4 - 4
core/blockchain_test.go

@@ -162,11 +162,11 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
 			blockchain.reportBlock(block, receipts, err)
 			blockchain.reportBlock(block, receipts, err)
 			return err
 			return err
 		}
 		}
-		blockchain.mu.Lock()
+		blockchain.chainmu.Lock()
 		rawdb.WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
 		rawdb.WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
 		rawdb.WriteBlock(blockchain.db, block)
 		rawdb.WriteBlock(blockchain.db, block)
 		statedb.Commit(false)
 		statedb.Commit(false)
-		blockchain.mu.Unlock()
+		blockchain.chainmu.Unlock()
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -180,10 +180,10 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error
 			return err
 			return err
 		}
 		}
 		// Manually insert the header into the database, but don't reorganise (allows subsequent testing)
 		// Manually insert the header into the database, but don't reorganise (allows subsequent testing)
-		blockchain.mu.Lock()
+		blockchain.chainmu.Lock()
 		rawdb.WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
 		rawdb.WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
 		rawdb.WriteHeader(blockchain.db, header)
 		rawdb.WriteHeader(blockchain.db, header)
-		blockchain.mu.Unlock()
+		blockchain.chainmu.Unlock()
 	}
 	}
 	return nil
 	return nil
 }
 }

+ 11 - 18
light/lightchain.go

@@ -57,7 +57,6 @@ type LightChain struct {
 	scope         event.SubscriptionScope
 	scope         event.SubscriptionScope
 	genesisBlock  *types.Block
 	genesisBlock  *types.Block
 
 
-	mu      sync.RWMutex
 	chainmu sync.RWMutex
 	chainmu sync.RWMutex
 
 
 	bodyCache    *lru.Cache // Cache for the most recent block bodies
 	bodyCache    *lru.Cache // Cache for the most recent block bodies
@@ -165,8 +164,8 @@ func (self *LightChain) loadLastState() error {
 // SetHead rewinds the local chain to a new head. Everything above the new
 // SetHead rewinds the local chain to a new head. Everything above the new
 // head will be deleted and the new one set.
 // head will be deleted and the new one set.
 func (bc *LightChain) SetHead(head uint64) {
 func (bc *LightChain) SetHead(head uint64) {
-	bc.mu.Lock()
-	defer bc.mu.Unlock()
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
 
 
 	bc.hc.SetHead(head, nil)
 	bc.hc.SetHead(head, nil)
 	bc.loadLastState()
 	bc.loadLastState()
@@ -188,8 +187,8 @@ func (bc *LightChain) ResetWithGenesisBlock(genesis *types.Block) {
 	// Dump the entire block chain and purge the caches
 	// Dump the entire block chain and purge the caches
 	bc.SetHead(0)
 	bc.SetHead(0)
 
 
-	bc.mu.Lock()
-	defer bc.mu.Unlock()
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
 
 
 	// Prepare the genesis block and reinitialise the chain
 	// Prepare the genesis block and reinitialise the chain
 	rawdb.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty())
 	rawdb.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty())
@@ -315,8 +314,8 @@ func (bc *LightChain) Stop() {
 // Rollback is designed to remove a chain of links from the database that aren't
 // Rollback is designed to remove a chain of links from the database that aren't
 // certain enough to be valid.
 // certain enough to be valid.
 func (self *LightChain) Rollback(chain []common.Hash) {
 func (self *LightChain) Rollback(chain []common.Hash) {
-	self.mu.Lock()
-	defer self.mu.Unlock()
+	self.chainmu.Lock()
+	defer self.chainmu.Unlock()
 
 
 	for i := len(chain) - 1; i >= 0; i-- {
 	for i := len(chain) - 1; i >= 0; i-- {
 		hash := chain[i]
 		hash := chain[i]
@@ -362,19 +361,13 @@ func (self *LightChain) InsertHeaderChain(chain []*types.Header, checkFreq int)
 
 
 	// Make sure only one thread manipulates the chain at once
 	// Make sure only one thread manipulates the chain at once
 	self.chainmu.Lock()
 	self.chainmu.Lock()
-	defer func() {
-		self.chainmu.Unlock()
-		time.Sleep(time.Millisecond * 10) // ugly hack; do not hog chain lock in case syncing is CPU-limited by validation
-	}()
+	defer self.chainmu.Unlock()
 
 
 	self.wg.Add(1)
 	self.wg.Add(1)
 	defer self.wg.Done()
 	defer self.wg.Done()
 
 
 	var events []interface{}
 	var events []interface{}
 	whFunc := func(header *types.Header) error {
 	whFunc := func(header *types.Header) error {
-		self.mu.Lock()
-		defer self.mu.Unlock()
-
 		status, err := self.hc.WriteHeader(header)
 		status, err := self.hc.WriteHeader(header)
 
 
 		switch status {
 		switch status {
@@ -441,8 +434,8 @@ func (self *LightChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []c
 //
 //
 // Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
 // Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
 func (bc *LightChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
 func (bc *LightChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
-	bc.chainmu.Lock()
-	defer bc.chainmu.Unlock()
+	bc.chainmu.RLock()
+	defer bc.chainmu.RUnlock()
 
 
 	return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
 	return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
 }
 }
@@ -483,8 +476,8 @@ func (self *LightChain) SyncCht(ctx context.Context) bool {
 	}
 	}
 	// Retrieve the latest useful header and update to it
 	// Retrieve the latest useful header and update to it
 	if header, err := GetHeaderByNumber(ctx, self.odr, latest); header != nil && err == nil {
 	if header, err := GetHeaderByNumber(ctx, self.odr, latest); header != nil && err == nil {
-		self.mu.Lock()
-		defer self.mu.Unlock()
+		self.chainmu.Lock()
+		defer self.chainmu.Unlock()
 
 
 		// Ensure the chain didn't move past the latest block while retrieving it
 		// Ensure the chain didn't move past the latest block while retrieving it
 		if self.hc.CurrentHeader().Number.Uint64() < header.Number.Uint64() {
 		if self.hc.CurrentHeader().Number.Uint64() < header.Number.Uint64() {

+ 2 - 2
light/lightchain_test.go

@@ -122,10 +122,10 @@ func testHeaderChainImport(chain []*types.Header, lightchain *LightChain) error
 			return err
 			return err
 		}
 		}
 		// Manually insert the header into the database, but don't reorganize (allows subsequent testing)
 		// Manually insert the header into the database, but don't reorganize (allows subsequent testing)
-		lightchain.mu.Lock()
+		lightchain.chainmu.Lock()
 		rawdb.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash)))
 		rawdb.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash)))
 		rawdb.WriteHeader(lightchain.chainDb, header)
 		rawdb.WriteHeader(lightchain.chainDb, header)
-		lightchain.mu.Unlock()
+		lightchain.chainmu.Unlock()
 	}
 	}
 	return nil
 	return nil
 }
 }