Selaa lähdekoodia

[R4R]implement diff sync (#376)

* implement block process part of light sync

* add difflayer protocol

* handle difflayer and refine light processor

* add testcase for diff protocol

* make it faster

* allow validator to light sync

* change into diff sync

* ligth sync: download difflayer (#2)

* ligth sync: download difflayer

Signed-off-by: kyrie-yl <lei.y@binance.com>

* download diff layer: fix according to the comments

Signed-off-by: kyrie-yl <lei.y@binance.com>

* download diff layer: update

Signed-off-by: kyrie-yl <lei.y@binance.com>

* download diff layer: fix accroding comments

Signed-off-by: kyrie-yl <lei.y@binance.com>

Co-authored-by: kyrie-yl <lei.y@binance.com>

* update light sync to diff sync

* raise the max diff limit

* add switcher of snap protocol

* fix test case

* make commit concurrently

* remove peer for diff cache when peer closed

* consensus tuning

* add test code

* remove extra message

* fix testcase and lint

make diff block configable

wait code write

fix testcase

resolve comments

resolve comment

* resolve comments

* resolve comments

* resolve comment

* fix mistake

Co-authored-by: kyrie-yl <83150977+kyrie-yl@users.noreply.github.com>
Co-authored-by: kyrie-yl <lei.y@binance.com>
zjubfd 4 vuotta sitten
vanhempi
commit
1ded097733
71 muutettua tiedostoa jossa 3352 lisäystä ja 394 poistoa
  1. 2 2
      cmd/evm/internal/t8ntool/execution.go
  2. 1 1
      cmd/faucet/faucet.go
  3. 4 0
      cmd/geth/main.go
  4. 1 0
      cmd/geth/usage.go
  5. 40 1
      cmd/utils/flags.go
  6. 13 1
      common/gopool/pool.go
  7. 1 0
      consensus/consensus.go
  8. 16 0
      consensus/parlia/parlia.go
  9. 1 1
      consensus/parlia/ramanujanfork.go
  10. 1 1
      consensus/parlia/snapshot.go
  11. 0 3
      core/block_validator.go
  12. 430 35
      core/blockchain.go
  13. 263 0
      core/blockchain_diff_test.go
  14. 3 3
      core/blockchain_test.go
  15. 1 1
      core/chain_makers.go
  16. 38 0
      core/rawdb/accessors_chain.go
  17. 26 0
      core/rawdb/database.go
  18. 3 57
      core/rawdb/freezer_table_test.go
  19. 8 0
      core/rawdb/schema.go
  20. 8 0
      core/rawdb/table.go
  21. 0 1
      core/state/database.go
  22. 1 1
      core/state/journal.go
  23. 4 4
      core/state/snapshot/disklayer_test.go
  24. 43 43
      core/state/snapshot/iterator_test.go
  25. 31 2
      core/state/snapshot/snapshot.go
  26. 12 12
      core/state/snapshot/snapshot_test.go
  27. 6 7
      core/state/state_object.go
  28. 1 1
      core/state/state_test.go
  29. 273 36
      core/state/statedb.go
  30. 7 7
      core/state/statedb_test.go
  31. 1 1
      core/state/sync_test.go
  32. 0 9
      core/state_prefetcher.go
  33. 322 7
      core/state_processor.go
  34. 1 1
      core/types.go
  35. 88 1
      core/types/block.go
  36. 1 1
      core/vm/contracts_lightclient_test.go
  37. 1 1
      core/vm/lightclient/types.go
  38. 15 3
      eth/backend.go
  39. 49 7
      eth/downloader/downloader.go
  40. 2 2
      eth/downloader/downloader_test.go
  41. 9 3
      eth/ethconfig/config.go
  42. 18 0
      eth/ethconfig/gen_config.go
  43. 12 1
      eth/fetcher/block_fetcher.go
  44. 18 18
      eth/fetcher/block_fetcher_test.go
  45. 38 3
      eth/handler.go
  46. 87 0
      eth/handler_diff.go
  47. 203 0
      eth/handler_diff_test.go
  48. 11 2
      eth/handler_eth.go
  49. 22 0
      eth/peer.go
  50. 86 1
      eth/peerset.go
  51. 32 0
      eth/protocols/diff/discovery.go
  52. 180 0
      eth/protocols/diff/handler.go
  53. 192 0
      eth/protocols/diff/handler_test.go
  54. 82 0
      eth/protocols/diff/handshake.go
  55. 107 0
      eth/protocols/diff/peer.go
  56. 61 0
      eth/protocols/diff/peer_test.go
  57. 122 0
      eth/protocols/diff/protocol.go
  58. 131 0
      eth/protocols/diff/protocol_test.go
  59. 161 0
      eth/protocols/diff/tracker.go
  60. 2 2
      eth/state_accessor.go
  61. 1 1
      eth/tracers/tracers_test.go
  62. 2 68
      ethclient/ethclient_test.go
  63. 6 0
      ethdb/database.go
  64. 2 2
      les/fetcher.go
  65. 1 1
      les/peer.go
  66. 3 9
      light/trie.go
  67. 0 19
      miner/worker.go
  68. 3 0
      node/config.go
  69. 41 0
      node/node.go
  70. 0 10
      rlp/typecache.go
  71. 1 1
      tests/state_test_util.go

+ 2 - 2
cmd/evm/internal/t8ntool/execution.go

@@ -223,7 +223,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig,
 		statedb.AddBalance(pre.Env.Coinbase, minerReward)
 	}
 	// Commit block
-	root, err := statedb.Commit(chainConfig.IsEIP158(vmContext.BlockNumber))
+	root, _, err := statedb.Commit(chainConfig.IsEIP158(vmContext.BlockNumber))
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "Could not commit state: %v", err)
 		return nil, nil, NewError(ErrorEVM, fmt.Errorf("could not commit state: %v", err))
@@ -252,7 +252,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB
 		}
 	}
 	// Commit and re-open to start with a clean state.
-	root, _ := statedb.Commit(false)
+	root, _, _ := statedb.Commit(false)
 	statedb, _ = state.New(root, sdb, nil)
 	return statedb
 }

+ 1 - 1
cmd/faucet/faucet.go

@@ -139,7 +139,7 @@ func main() {
 		log.Crit("Length of bep2eContracts, bep2eSymbols, bep2eAmounts mismatch")
 	}
 
-	bep2eInfos := make(map[string]bep2eInfo, 0)
+	bep2eInfos := make(map[string]bep2eInfo, len(symbols))
 	for idx, s := range symbols {
 		n, ok := big.NewInt(0).SetString(bep2eNumAmounts[idx], 10)
 		if !ok {

+ 4 - 0
cmd/geth/main.go

@@ -65,6 +65,8 @@ var (
 		utils.ExternalSignerFlag,
 		utils.NoUSBFlag,
 		utils.DirectBroadcastFlag,
+		utils.DisableSnapProtocolFlag,
+		utils.DiffSyncFlag,
 		utils.RangeLimitFlag,
 		utils.USBFlag,
 		utils.SmartCardDaemonPathFlag,
@@ -114,6 +116,8 @@ var (
 		utils.CacheGCFlag,
 		utils.CacheSnapshotFlag,
 		utils.CachePreimagesFlag,
+		utils.PersistDiffFlag,
+		utils.DiffBlockFlag,
 		utils.ListenPortFlag,
 		utils.MaxPeersFlag,
 		utils.MaxPendingPeersFlag,

+ 1 - 0
cmd/geth/usage.go

@@ -40,6 +40,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{
 			utils.KeyStoreDirFlag,
 			utils.NoUSBFlag,
 			utils.DirectBroadcastFlag,
+			utils.DisableSnapProtocolFlag,
 			utils.RangeLimitFlag,
 			utils.SmartCardDaemonPathFlag,
 			utils.NetworkIdFlag,

+ 40 - 1
cmd/utils/flags.go

@@ -117,6 +117,15 @@ var (
 		Name:  "directbroadcast",
 		Usage: "Enable directly broadcast mined block to all peers",
 	}
+	DisableSnapProtocolFlag = cli.BoolFlag{
+		Name:  "disablesnapprotocol",
+		Usage: "Disable snap protocol",
+	}
+	DiffSyncFlag = cli.BoolFlag{
+		Name: "diffsync",
+		Usage: "Enable diffy sync, Please note that enable diffsync will improve the syncing speed, " +
+			"but will degrade the security to light client level",
+	}
 	RangeLimitFlag = cli.BoolFlag{
 		Name:  "rangelimit",
 		Usage: "Enable 5000 blocks limit for range query",
@@ -125,6 +134,10 @@ var (
 		Name:  "datadir.ancient",
 		Usage: "Data directory for ancient chain segments (default = inside chaindata)",
 	}
+	DiffFlag = DirectoryFlag{
+		Name:  "datadir.diff",
+		Usage: "Data directory for difflayer segments (default = inside chaindata)",
+	}
 	MinFreeDiskSpaceFlag = DirectoryFlag{
 		Name:  "datadir.minfreedisk",
 		Usage: "Minimum free disk space in MB, once reached triggers auto shut down (default = --cache.gc converted to MB, 0 = disabled)",
@@ -425,6 +438,15 @@ var (
 		Name:  "cache.preimages",
 		Usage: "Enable recording the SHA3/keccak preimages of trie keys",
 	}
+	PersistDiffFlag = cli.BoolFlag{
+		Name:  "persistdiff",
+		Usage: "Enable persistence of the diff layer",
+	}
+	DiffBlockFlag = cli.Uint64Flag{
+		Name:  "diffblock",
+		Usage: "The number of blocks should be persisted in db (default = 864000 )",
+		Value: uint64(864000),
+	}
 	// Miner settings
 	MiningEnabledFlag = cli.BoolFlag{
 		Name:  "mine",
@@ -1271,6 +1293,9 @@ func SetNodeConfig(ctx *cli.Context, cfg *node.Config) {
 	if ctx.GlobalIsSet(DirectBroadcastFlag.Name) {
 		cfg.DirectBroadcast = ctx.GlobalBool(DirectBroadcastFlag.Name)
 	}
+	if ctx.GlobalIsSet(DisableSnapProtocolFlag.Name) {
+		cfg.DisableSnapProtocol = ctx.GlobalBool(DisableSnapProtocolFlag.Name)
+	}
 	if ctx.GlobalIsSet(RangeLimitFlag.Name) {
 		cfg.RangeLimit = ctx.GlobalBool(RangeLimitFlag.Name)
 	}
@@ -1564,7 +1589,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
 	if ctx.GlobalIsSet(AncientFlag.Name) {
 		cfg.DatabaseFreezer = ctx.GlobalString(AncientFlag.Name)
 	}
-
+	if ctx.GlobalIsSet(DiffFlag.Name) {
+		cfg.DatabaseDiff = ctx.GlobalString(DiffFlag.Name)
+	}
+	if ctx.GlobalIsSet(PersistDiffFlag.Name) {
+		cfg.PersistDiff = ctx.GlobalBool(PersistDiffFlag.Name)
+	}
+	if ctx.GlobalIsSet(DiffBlockFlag.Name) {
+		cfg.DiffBlock = ctx.GlobalUint64(DiffBlockFlag.Name)
+	}
 	if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" {
 		Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name)
 	}
@@ -1574,6 +1607,12 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
 	if ctx.GlobalIsSet(DirectBroadcastFlag.Name) {
 		cfg.DirectBroadcast = ctx.GlobalBool(DirectBroadcastFlag.Name)
 	}
+	if ctx.GlobalIsSet(DisableSnapProtocolFlag.Name) {
+		cfg.DisableSnapProtocol = ctx.GlobalBool(DisableSnapProtocolFlag.Name)
+	}
+	if ctx.GlobalIsSet(DiffSyncFlag.Name) {
+		cfg.DiffSync = ctx.GlobalBool(DiffSyncFlag.Name)
+	}
 	if ctx.GlobalIsSet(RangeLimitFlag.Name) {
 		cfg.RangeLimit = ctx.GlobalBool(RangeLimitFlag.Name)
 	}

+ 13 - 1
common/gopool/pool.go

@@ -1,6 +1,7 @@
 package gopool
 
 import (
+	"runtime"
 	"time"
 
 	"github.com/panjf2000/ants/v2"
@@ -8,7 +9,8 @@ import (
 
 var (
 	// Init a instance pool when importing ants.
-	defaultPool, _ = ants.NewPool(ants.DefaultAntsPoolSize, ants.WithExpiryDuration(10*time.Second))
+	defaultPool, _   = ants.NewPool(ants.DefaultAntsPoolSize, ants.WithExpiryDuration(10*time.Second))
+	minNumberPerTask = 5
 )
 
 // Logger is used for logging formatted messages.
@@ -46,3 +48,13 @@ func Release() {
 func Reboot() {
 	defaultPool.Reboot()
 }
+
+func Threads(tasks int) int {
+	threads := tasks / minNumberPerTask
+	if threads > runtime.NumCPU() {
+		threads = runtime.NumCPU()
+	} else if threads == 0 {
+		threads = 1
+	}
+	return threads
+}

+ 1 - 0
consensus/consensus.go

@@ -141,4 +141,5 @@ type PoSA interface {
 	IsSystemContract(to *common.Address) bool
 	EnoughDistance(chain ChainReader, header *types.Header) bool
 	IsLocalBlock(header *types.Header) bool
+	AllowLightProcess(chain ChainReader, currentHeader *types.Header) bool
 }

+ 16 - 0
consensus/parlia/parlia.go

@@ -799,6 +799,11 @@ func (p *Parlia) Delay(chain consensus.ChainReader, header *types.Header) *time.
 		return nil
 	}
 	delay := p.delayForRamanujanFork(snap, header)
+	// The blocking time should be no more than half of period
+	half := time.Duration(p.config.Period) * time.Second / 2
+	if delay > half {
+		delay = half
+	}
 	return &delay
 }
 
@@ -882,6 +887,17 @@ func (p *Parlia) EnoughDistance(chain consensus.ChainReader, header *types.Heade
 	return snap.enoughDistance(p.val, header)
 }
 
+func (p *Parlia) AllowLightProcess(chain consensus.ChainReader, currentHeader *types.Header) bool {
+	snap, err := p.snapshot(chain, currentHeader.Number.Uint64()-1, currentHeader.ParentHash, nil)
+	if err != nil {
+		return true
+	}
+
+	idx := snap.indexOfVal(p.val)
+	// validator is not allowed to diff sync
+	return idx < 0
+}
+
 func (p *Parlia) IsLocalBlock(header *types.Header) bool {
 	return p.val == header.Coinbase
 }

+ 1 - 1
consensus/parlia/ramanujanfork.go

@@ -21,7 +21,7 @@ func (p *Parlia) delayForRamanujanFork(snap *Snapshot, header *types.Header) tim
 	if header.Difficulty.Cmp(diffNoTurn) == 0 {
 		// It's not our turn explicitly to sign, delay it a bit
 		wiggle := time.Duration(len(snap.Validators)/2+1) * wiggleTimeBeforeFork
-		delay += time.Duration(fixedBackOffTimeBeforeFork) + time.Duration(rand.Int63n(int64(wiggle)))
+		delay += fixedBackOffTimeBeforeFork + time.Duration(rand.Int63n(int64(wiggle)))
 	}
 	return delay
 }

+ 1 - 1
consensus/parlia/snapshot.go

@@ -256,7 +256,7 @@ func (s *Snapshot) enoughDistance(validator common.Address, header *types.Header
 	if validator == header.Coinbase {
 		return false
 	}
-	offset := (int64(s.Number) + 1) % int64(validatorNum)
+	offset := (int64(s.Number) + 1) % validatorNum
 	if int64(idx) >= offset {
 		return int64(idx)-offset >= validatorNum-2
 	} else {

+ 0 - 3
core/block_validator.go

@@ -17,9 +17,7 @@
 package core
 
 import (
-	"encoding/json"
 	"fmt"
-	"os"
 
 	"github.com/ethereum/go-ethereum/consensus"
 	"github.com/ethereum/go-ethereum/core/state"
@@ -133,7 +131,6 @@ func (v *BlockValidator) ValidateState(block *types.Block, statedb *state.StateD
 		},
 		func() error {
 			if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); header.Root != root {
-				statedb.IterativeDump(true, true, true, json.NewEncoder(os.Stdout))
 				return fmt.Errorf("invalid merkle root (remote: %x local: %x)", header.Root, root)
 			} else {
 				return nil

+ 430 - 35
core/blockchain.go

@@ -80,14 +80,22 @@ var (
 )
 
 const (
-	bodyCacheLimit      = 256
-	blockCacheLimit     = 256
-	receiptsCacheLimit  = 10000
-	txLookupCacheLimit  = 1024
-	maxFutureBlocks     = 256
-	maxTimeFutureBlocks = 30
-	badBlockLimit       = 10
-	maxBeyondBlocks     = 2048
+	bodyCacheLimit         = 256
+	blockCacheLimit        = 256
+	diffLayerCacheLimit    = 1024
+	diffLayerRLPCacheLimit = 256
+	receiptsCacheLimit     = 10000
+	txLookupCacheLimit     = 1024
+	maxFutureBlocks        = 256
+	maxTimeFutureBlocks    = 30
+	maxBeyondBlocks        = 2048
+
+	diffLayerFreezerRecheckInterval = 3 * time.Second
+	diffLayerPruneRecheckInterval   = 1 * time.Second // The interval to prune unverified diff layers
+	maxDiffQueueDist                = 2048            // Maximum allowed distance from the chain head to queue diffLayers
+	maxDiffLimit                    = 2048            // Maximum number of unique diff layers a peer may have responded
+	maxDiffForkDist                 = 11              // Maximum allowed backward distance from the chain head
+	maxDiffLimitForBroadcast        = 128             // Maximum number of unique diff layers a peer may have broadcasted
 
 	// BlockChainVersion ensures that an incompatible database forces a resync from scratch.
 	//
@@ -131,6 +139,11 @@ type CacheConfig struct {
 	SnapshotWait bool // Wait for snapshot construction on startup. TODO(karalabe): This is a dirty hack for testing, nuke it
 }
 
+// To avoid cycle import
+type PeerIDer interface {
+	ID() string
+}
+
 // defaultCacheConfig are the default caching values if none are specified by the
 // user (also used during testing).
 var defaultCacheConfig = &CacheConfig{
@@ -142,6 +155,8 @@ var defaultCacheConfig = &CacheConfig{
 	SnapshotWait:   true,
 }
 
+type BlockChainOption func(*BlockChain) *BlockChain
+
 // BlockChain represents the canonical chain given a database with a genesis
 // block. The Blockchain manages chain imports, reverts, chain reorganisations.
 //
@@ -196,6 +211,21 @@ type BlockChain struct {
 	txLookupCache *lru.Cache     // Cache for the most recent transaction lookup data.
 	futureBlocks  *lru.Cache     // future blocks are blocks added for later processing
 
+	// trusted diff layers
+	diffLayerCache             *lru.Cache   // Cache for the diffLayers
+	diffLayerRLPCache          *lru.Cache   // Cache for the rlp encoded diffLayers
+	diffQueue                  *prque.Prque // A Priority queue to store recent diff layer
+	diffQueueBuffer            chan *types.DiffLayer
+	diffLayerFreezerBlockLimit uint64
+
+	// untrusted diff layers
+	diffMux               sync.RWMutex
+	blockHashToDiffLayers map[common.Hash]map[common.Hash]*types.DiffLayer // map[blockHash] map[DiffHash]Diff
+	diffHashToBlockHash   map[common.Hash]common.Hash                      // map[diffHash]blockHash
+	diffHashToPeers       map[common.Hash]map[string]struct{}              // map[diffHash]map[pid]
+	diffNumToBlockHashes  map[uint64]map[common.Hash]struct{}              // map[number]map[blockHash]
+	diffPeersToDiffHashes map[string]map[common.Hash]struct{}              // map[pid]map[diffHash]
+
 	quit          chan struct{}  // blockchain quit channel
 	wg            sync.WaitGroup // chain processing wait group for shutting down
 	running       int32          // 0 if chain is running, 1 when stopped
@@ -213,12 +243,15 @@ type BlockChain struct {
 // NewBlockChain returns a fully initialised block chain using information
 // available in the database. It initialises the default Ethereum Validator and
 // Processor.
-func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config, shouldPreserve func(block *types.Block) bool, txLookupLimit *uint64) (*BlockChain, error) {
+func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine,
+	vmConfig vm.Config, shouldPreserve func(block *types.Block) bool, txLookupLimit *uint64,
+	options ...BlockChainOption) (*BlockChain, error) {
 	if cacheConfig == nil {
 		cacheConfig = defaultCacheConfig
 	}
 	if cacheConfig.TriesInMemory != 128 {
-		log.Warn("TriesInMemory isn't the default value(128), you need specify exact same TriesInMemory when prune data", "triesInMemory", cacheConfig.TriesInMemory)
+		log.Warn("TriesInMemory isn't the default value(128), you need specify exact same TriesInMemory when prune data",
+			"triesInMemory", cacheConfig.TriesInMemory)
 	}
 	bodyCache, _ := lru.New(bodyCacheLimit)
 	bodyRLPCache, _ := lru.New(bodyCacheLimit)
@@ -226,6 +259,8 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 	blockCache, _ := lru.New(blockCacheLimit)
 	txLookupCache, _ := lru.New(txLookupCacheLimit)
 	futureBlocks, _ := lru.New(maxFutureBlocks)
+	diffLayerCache, _ := lru.New(diffLayerCacheLimit)
+	diffLayerRLPCache, _ := lru.New(diffLayerRLPCacheLimit)
 
 	bc := &BlockChain{
 		chainConfig: chainConfig,
@@ -237,17 +272,26 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 			Journal:   cacheConfig.TrieCleanJournal,
 			Preimages: cacheConfig.Preimages,
 		}),
-		triesInMemory:  cacheConfig.TriesInMemory,
-		quit:           make(chan struct{}),
-		shouldPreserve: shouldPreserve,
-		bodyCache:      bodyCache,
-		bodyRLPCache:   bodyRLPCache,
-		receiptsCache:  receiptsCache,
-		blockCache:     blockCache,
-		txLookupCache:  txLookupCache,
-		futureBlocks:   futureBlocks,
-		engine:         engine,
-		vmConfig:       vmConfig,
+		triesInMemory:         cacheConfig.TriesInMemory,
+		quit:                  make(chan struct{}),
+		shouldPreserve:        shouldPreserve,
+		bodyCache:             bodyCache,
+		bodyRLPCache:          bodyRLPCache,
+		receiptsCache:         receiptsCache,
+		blockCache:            blockCache,
+		diffLayerCache:        diffLayerCache,
+		diffLayerRLPCache:     diffLayerRLPCache,
+		txLookupCache:         txLookupCache,
+		futureBlocks:          futureBlocks,
+		engine:                engine,
+		vmConfig:              vmConfig,
+		diffQueue:             prque.New(nil),
+		diffQueueBuffer:       make(chan *types.DiffLayer),
+		blockHashToDiffLayers: make(map[common.Hash]map[common.Hash]*types.DiffLayer),
+		diffHashToBlockHash:   make(map[common.Hash]common.Hash),
+		diffHashToPeers:       make(map[common.Hash]map[string]struct{}),
+		diffNumToBlockHashes:  make(map[uint64]map[common.Hash]struct{}),
+		diffPeersToDiffHashes: make(map[string]map[common.Hash]struct{}),
 	}
 	bc.validator = NewBlockValidator(chainConfig, bc, engine)
 	bc.processor = NewStateProcessor(chainConfig, bc, engine)
@@ -375,6 +419,10 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 		}
 		bc.snaps, _ = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, int(bc.cacheConfig.TriesInMemory), head.Root(), !bc.cacheConfig.SnapshotWait, true, recover)
 	}
+	// do options before start any routine
+	for _, option := range options {
+		bc = option(bc)
+	}
 	// Take ownership of this particular state
 	go bc.update()
 	if txLookupLimit != nil {
@@ -396,6 +444,12 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 			triedb.SaveCachePeriodically(bc.cacheConfig.TrieCleanJournal, bc.cacheConfig.TrieCleanRejournal, bc.quit)
 		}()
 	}
+	// Need persist and prune diff layer
+	if bc.db.DiffStore() != nil {
+		go bc.trustedDiffLayerLoop()
+	}
+	go bc.untrustedDiffLayerPruneLoop()
+
 	return bc, nil
 }
 
@@ -404,11 +458,19 @@ func (bc *BlockChain) GetVMConfig() *vm.Config {
 	return &bc.vmConfig
 }
 
-func (bc *BlockChain) CacheReceipts(hash common.Hash, receipts types.Receipts) {
+func (bc *BlockChain) cacheReceipts(hash common.Hash, receipts types.Receipts) {
 	bc.receiptsCache.Add(hash, receipts)
 }
 
-func (bc *BlockChain) CacheBlock(hash common.Hash, block *types.Block) {
+func (bc *BlockChain) cacheDiffLayer(diffLayer *types.DiffLayer) {
+	bc.diffLayerCache.Add(diffLayer.BlockHash, diffLayer)
+	if bc.db.DiffStore() != nil {
+		// push to priority queue before persisting
+		bc.diffQueueBuffer <- diffLayer
+	}
+}
+
+func (bc *BlockChain) cacheBlock(hash common.Hash, block *types.Block) {
 	bc.blockCache.Add(hash, block)
 }
 
@@ -873,6 +935,45 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue {
 	return body
 }
 
+// GetDiffLayerRLP retrieves a diff layer in RLP encoding from the cache or database by blockHash
+func (bc *BlockChain) GetDiffLayerRLP(blockHash common.Hash) rlp.RawValue {
+	// Short circuit if the diffLayer's already in the cache, retrieve otherwise
+	if cached, ok := bc.diffLayerRLPCache.Get(blockHash); ok {
+		return cached.(rlp.RawValue)
+	}
+	if cached, ok := bc.diffLayerCache.Get(blockHash); ok {
+		diff := cached.(*types.DiffLayer)
+		bz, err := rlp.EncodeToBytes(diff)
+		if err != nil {
+			return nil
+		}
+		bc.diffLayerRLPCache.Add(blockHash, rlp.RawValue(bz))
+		return bz
+	}
+
+	// fallback to untrusted sources.
+	diff := bc.GetUnTrustedDiffLayer(blockHash, "")
+	if diff != nil {
+		bz, err := rlp.EncodeToBytes(diff)
+		if err != nil {
+			return nil
+		}
+		// No need to cache untrusted data
+		return bz
+	}
+
+	// fallback to disk
+	diffStore := bc.db.DiffStore()
+	if diffStore == nil {
+		return nil
+	}
+	rawData := rawdb.ReadDiffLayerRLP(diffStore, blockHash)
+	if len(rawData) != 0 {
+		bc.diffLayerRLPCache.Add(blockHash, rawData)
+	}
+	return rawData
+}
+
 // HasBlock checks if a block is fully present in the database or not.
 func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool {
 	if bc.blockCache.Contains(hash) {
@@ -1506,10 +1607,19 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
 		wg.Done()
 	}()
 	// Commit all cached state changes into underlying memory database.
-	root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number()))
+	root, diffLayer, err := state.Commit(bc.chainConfig.IsEIP158(block.Number()))
 	if err != nil {
 		return NonStatTy, err
 	}
+
+	// Ensure no empty block body
+	if diffLayer != nil && block.Header().TxHash != types.EmptyRootHash {
+		// Filling necessary field
+		diffLayer.Receipts = receipts
+		diffLayer.BlockHash = block.Hash()
+		diffLayer.Number = block.NumberU64()
+		bc.cacheDiffLayer(diffLayer)
+	}
 	triedb := bc.stateCache.TrieDB()
 
 	// If we're running an archive node, always flush
@@ -1885,18 +1995,15 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
 		}
 		// Enable prefetching to pull in trie node paths while processing transactions
 		statedb.StartPrefetcher("chain")
-		activeState = statedb
-		statedb.TryPreload(block, signer)
 
 		//Process block using the parent state as reference point
 		substart := time.Now()
-		receipts, logs, usedGas, err := bc.processor.Process(block, statedb, bc.vmConfig)
+		statedb, receipts, logs, usedGas, err := bc.processor.Process(block, statedb, bc.vmConfig)
+		activeState = statedb
 		if err != nil {
 			bc.reportBlock(block, receipts, err)
 			return it.index, err
 		}
-		bc.CacheReceipts(block.Hash(), receipts)
-		bc.CacheBlock(block.Hash(), block)
 		// Update the metrics touched during block processing
 		accountReadTimer.Update(statedb.AccountReads)                 // Account reads are complete, we can mark them
 		storageReadTimer.Update(statedb.StorageReads)                 // Storage reads are complete, we can mark them
@@ -1904,18 +2011,20 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
 		storageUpdateTimer.Update(statedb.StorageUpdates)             // Storage updates are complete, we can mark them
 		snapshotAccountReadTimer.Update(statedb.SnapshotAccountReads) // Account reads are complete, we can mark them
 		snapshotStorageReadTimer.Update(statedb.SnapshotStorageReads) // Storage reads are complete, we can mark them
-		trieproc := statedb.SnapshotAccountReads + statedb.AccountReads + statedb.AccountUpdates
-		trieproc += statedb.SnapshotStorageReads + statedb.StorageReads + statedb.StorageUpdates
 
 		blockExecutionTimer.Update(time.Since(substart))
 
 		// Validate the state using the default validator
 		substart = time.Now()
-		if err := bc.validator.ValidateState(block, statedb, receipts, usedGas); err != nil {
-			bc.reportBlock(block, receipts, err)
-			log.Error("validate state failed", "error", err)
-			return it.index, err
+		if !statedb.IsLightProcessed() {
+			if err := bc.validator.ValidateState(block, statedb, receipts, usedGas); err != nil {
+				log.Error("validate state failed", "error", err)
+				bc.reportBlock(block, receipts, err)
+				return it.index, err
+			}
 		}
+		bc.cacheReceipts(block.Hash(), receipts)
+		bc.cacheBlock(block.Hash(), block)
 		proctime := time.Since(start)
 
 		// Update the metrics touched during block validation
@@ -2292,6 +2401,279 @@ func (bc *BlockChain) update() {
 	}
 }
 
+func (bc *BlockChain) trustedDiffLayerLoop() {
+	recheck := time.Tick(diffLayerFreezerRecheckInterval)
+	bc.wg.Add(1)
+	defer bc.wg.Done()
+	for {
+		select {
+		case diff := <-bc.diffQueueBuffer:
+			bc.diffQueue.Push(diff, -(int64(diff.Number)))
+		case <-bc.quit:
+			// Persist all diffLayers when shutdown, it will introduce redundant storage, but it is acceptable.
+			// If the client been ungracefully shutdown, it will missing all cached diff layers, it is acceptable as well.
+			var batch ethdb.Batch
+			for !bc.diffQueue.Empty() {
+				diff, _ := bc.diffQueue.Pop()
+				diffLayer := diff.(*types.DiffLayer)
+				if batch == nil {
+					batch = bc.db.DiffStore().NewBatch()
+				}
+				rawdb.WriteDiffLayer(batch, diffLayer.BlockHash, diffLayer)
+				if batch.ValueSize() > ethdb.IdealBatchSize {
+					if err := batch.Write(); err != nil {
+						log.Error("Failed to write diff layer", "err", err)
+						return
+					}
+					batch.Reset()
+				}
+			}
+			if batch != nil {
+				// flush data
+				if err := batch.Write(); err != nil {
+					log.Error("Failed to write diff layer", "err", err)
+					return
+				}
+				batch.Reset()
+			}
+			return
+		case <-recheck:
+			currentHeight := bc.CurrentBlock().NumberU64()
+			var batch ethdb.Batch
+			for !bc.diffQueue.Empty() {
+				diff, prio := bc.diffQueue.Pop()
+				diffLayer := diff.(*types.DiffLayer)
+
+				// if the block old enough
+				if int64(currentHeight)+prio >= int64(bc.triesInMemory) {
+					canonicalHash := bc.GetCanonicalHash(uint64(-prio))
+					// on the canonical chain
+					if canonicalHash == diffLayer.BlockHash {
+						if batch == nil {
+							batch = bc.db.DiffStore().NewBatch()
+						}
+						rawdb.WriteDiffLayer(batch, diffLayer.BlockHash, diffLayer)
+						staleHash := bc.GetCanonicalHash(uint64(-prio) - bc.diffLayerFreezerBlockLimit)
+						rawdb.DeleteDiffLayer(batch, staleHash)
+					}
+				} else {
+					bc.diffQueue.Push(diffLayer, prio)
+					break
+				}
+				if batch != nil && batch.ValueSize() > ethdb.IdealBatchSize {
+					if err := batch.Write(); err != nil {
+						panic(fmt.Sprintf("Failed to write diff layer, error %v", err))
+					}
+					batch.Reset()
+				}
+			}
+			if batch != nil {
+				if err := batch.Write(); err != nil {
+					panic(fmt.Sprintf("Failed to write diff layer, error %v", err))
+				}
+				batch.Reset()
+			}
+		}
+	}
+}
+
+func (bc *BlockChain) GetUnTrustedDiffLayer(blockHash common.Hash, pid string) *types.DiffLayer {
+	bc.diffMux.RLock()
+	defer bc.diffMux.RUnlock()
+	if diffs, exist := bc.blockHashToDiffLayers[blockHash]; exist && len(diffs) != 0 {
+		if len(diffs) == 1 {
+			// return the only one diff layer
+			for _, diff := range diffs {
+				return diff
+			}
+		}
+		// pick the one from exact same peer if we know where the block comes from
+		if pid != "" {
+			if diffHashes, exist := bc.diffPeersToDiffHashes[pid]; exist {
+				for diff := range diffs {
+					if _, overlap := diffHashes[diff]; overlap {
+						return bc.blockHashToDiffLayers[blockHash][diff]
+					}
+				}
+			}
+		}
+		// Do not find overlap, do random pick
+		for _, diff := range diffs {
+			return diff
+		}
+	}
+	return nil
+}
+
+func (bc *BlockChain) removeDiffLayers(diffHash common.Hash) {
+	bc.diffMux.Lock()
+	defer bc.diffMux.Unlock()
+
+	// Untrusted peers
+	pids := bc.diffHashToPeers[diffHash]
+	invalidDiffHashes := make(map[common.Hash]struct{})
+	for pid := range pids {
+		invaliDiffHashesPeer := bc.diffPeersToDiffHashes[pid]
+		for invaliDiffHash := range invaliDiffHashesPeer {
+			invalidDiffHashes[invaliDiffHash] = struct{}{}
+		}
+		delete(bc.diffPeersToDiffHashes, pid)
+	}
+	for invalidDiffHash := range invalidDiffHashes {
+		delete(bc.diffHashToPeers, invalidDiffHash)
+		affectedBlockHash := bc.diffHashToBlockHash[invalidDiffHash]
+		if diffs, exist := bc.blockHashToDiffLayers[affectedBlockHash]; exist {
+			delete(diffs, invalidDiffHash)
+			if len(diffs) == 0 {
+				delete(bc.blockHashToDiffLayers, affectedBlockHash)
+			}
+		}
+		delete(bc.diffHashToBlockHash, invalidDiffHash)
+	}
+}
+
+func (bc *BlockChain) RemoveDiffPeer(pid string) {
+	bc.diffMux.Lock()
+	defer bc.diffMux.Unlock()
+	if invaliDiffHashes := bc.diffPeersToDiffHashes[pid]; invaliDiffHashes != nil {
+		for invalidDiffHash := range invaliDiffHashes {
+			lastDiffHash := false
+			if peers, ok := bc.diffHashToPeers[invalidDiffHash]; ok {
+				delete(peers, pid)
+				if len(peers) == 0 {
+					lastDiffHash = true
+					delete(bc.diffHashToPeers, invalidDiffHash)
+				}
+			}
+			if lastDiffHash {
+				affectedBlockHash := bc.diffHashToBlockHash[invalidDiffHash]
+				if diffs, exist := bc.blockHashToDiffLayers[affectedBlockHash]; exist {
+					delete(diffs, invalidDiffHash)
+					if len(diffs) == 0 {
+						delete(bc.blockHashToDiffLayers, affectedBlockHash)
+					}
+				}
+				delete(bc.diffHashToBlockHash, invalidDiffHash)
+			}
+		}
+		delete(bc.diffPeersToDiffHashes, pid)
+	}
+}
+
+func (bc *BlockChain) untrustedDiffLayerPruneLoop() {
+	recheck := time.NewTicker(diffLayerPruneRecheckInterval)
+	bc.wg.Add(1)
+	defer func() {
+		bc.wg.Done()
+		recheck.Stop()
+	}()
+	for {
+		select {
+		case <-bc.quit:
+			return
+		case <-recheck.C:
+			bc.pruneDiffLayer()
+		}
+	}
+}
+
+func (bc *BlockChain) pruneDiffLayer() {
+	currentHeight := bc.CurrentBlock().NumberU64()
+	bc.diffMux.Lock()
+	defer bc.diffMux.Unlock()
+	sortNumbers := make([]uint64, 0, len(bc.diffNumToBlockHashes))
+	for number := range bc.diffNumToBlockHashes {
+		sortNumbers = append(sortNumbers, number)
+	}
+	sort.Slice(sortNumbers, func(i, j int) bool {
+		return sortNumbers[i] <= sortNumbers[j]
+	})
+	staleBlockHashes := make(map[common.Hash]struct{})
+	for _, number := range sortNumbers {
+		if number >= currentHeight-maxDiffForkDist {
+			break
+		}
+		affectedHashes := bc.diffNumToBlockHashes[number]
+		if affectedHashes != nil {
+			for affectedHash := range affectedHashes {
+				staleBlockHashes[affectedHash] = struct{}{}
+			}
+			delete(bc.diffNumToBlockHashes, number)
+		}
+	}
+	staleDiffHashes := make(map[common.Hash]struct{})
+	for blockHash := range staleBlockHashes {
+		if diffHashes, exist := bc.blockHashToDiffLayers[blockHash]; exist {
+			for diffHash := range diffHashes {
+				staleDiffHashes[diffHash] = struct{}{}
+				delete(bc.diffHashToBlockHash, diffHash)
+				delete(bc.diffHashToPeers, diffHash)
+			}
+		}
+		delete(bc.blockHashToDiffLayers, blockHash)
+	}
+	for diffHash := range staleDiffHashes {
+		for p, diffHashes := range bc.diffPeersToDiffHashes {
+			delete(diffHashes, diffHash)
+			if len(diffHashes) == 0 {
+				delete(bc.diffPeersToDiffHashes, p)
+			}
+		}
+	}
+}
+
+// Process received diff layers
+func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string, fulfilled bool) error {
+	// Basic check
+	currentHeight := bc.CurrentBlock().NumberU64()
+	if diffLayer.Number > currentHeight && diffLayer.Number-currentHeight > maxDiffQueueDist {
+		log.Error("diff layers too new from current", "pid", pid)
+		return nil
+	}
+	if diffLayer.Number < currentHeight && currentHeight-diffLayer.Number > maxDiffForkDist {
+		log.Error("diff layers too old from current", "pid", pid)
+		return nil
+	}
+
+	bc.diffMux.Lock()
+	defer bc.diffMux.Unlock()
+
+	if !fulfilled && len(bc.diffPeersToDiffHashes[pid]) > maxDiffLimitForBroadcast {
+		log.Error("too many accumulated diffLayers", "pid", pid)
+		return nil
+	}
+
+	if len(bc.diffPeersToDiffHashes[pid]) > maxDiffLimit {
+		log.Error("too many accumulated diffLayers", "pid", pid)
+		return nil
+	}
+	if _, exist := bc.diffPeersToDiffHashes[pid]; exist {
+		if _, alreadyHas := bc.diffPeersToDiffHashes[pid][diffLayer.DiffHash]; alreadyHas {
+			return nil
+		}
+	} else {
+		bc.diffPeersToDiffHashes[pid] = make(map[common.Hash]struct{})
+	}
+	bc.diffPeersToDiffHashes[pid][diffLayer.DiffHash] = struct{}{}
+	if _, exist := bc.diffNumToBlockHashes[diffLayer.Number]; !exist {
+		bc.diffNumToBlockHashes[diffLayer.Number] = make(map[common.Hash]struct{})
+	}
+	bc.diffNumToBlockHashes[diffLayer.Number][diffLayer.BlockHash] = struct{}{}
+
+	if _, exist := bc.diffHashToPeers[diffLayer.DiffHash]; !exist {
+		bc.diffHashToPeers[diffLayer.DiffHash] = make(map[string]struct{})
+	}
+	bc.diffHashToPeers[diffLayer.DiffHash][pid] = struct{}{}
+
+	if _, exist := bc.blockHashToDiffLayers[diffLayer.BlockHash]; !exist {
+		bc.blockHashToDiffLayers[diffLayer.BlockHash] = make(map[common.Hash]*types.DiffLayer)
+	}
+	bc.blockHashToDiffLayers[diffLayer.BlockHash][diffLayer.DiffHash] = diffLayer
+	bc.diffHashToBlockHash[diffLayer.DiffHash] = diffLayer.BlockHash
+
+	return nil
+}
+
 // maintainTxIndex is responsible for the construction and deletion of the
 // transaction index.
 //
@@ -2541,3 +2923,16 @@ func (bc *BlockChain) SubscribeLogsEvent(ch chan<- []*types.Log) event.Subscript
 func (bc *BlockChain) SubscribeBlockProcessingEvent(ch chan<- bool) event.Subscription {
 	return bc.scope.Track(bc.blockProcFeed.Subscribe(ch))
 }
+
+// Options
+func EnableLightProcessor(bc *BlockChain) *BlockChain {
+	bc.processor = NewLightStateProcessor(bc.Config(), bc, bc.engine)
+	return bc
+}
+
+func EnablePersistDiff(limit uint64) BlockChainOption {
+	return func(chain *BlockChain) *BlockChain {
+		chain.diffLayerFreezerBlockLimit = limit
+		return chain
+	}
+}

+ 263 - 0
core/blockchain_diff_test.go

@@ -0,0 +1,263 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Tests that abnormal program termination (i.e.crash) and restart doesn't leave
+// the database in some strange state with gaps in the chain, nor with block data
+// dangling in the future.
+
+package core
+
+import (
+	"math/big"
+	"testing"
+	"time"
+
+	"golang.org/x/crypto/sha3"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/state/snapshot"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/core/vm"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/ethdb/memorydb"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+var (
+	// testKey is a private key to use for funding a tester account.
+	testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+	// testAddr is the Ethereum address of the tester account.
+	testAddr = crypto.PubkeyToAddress(testKey.PublicKey)
+)
+
+// testBackend is a mock implementation of the live Ethereum message handler. Its
+// purpose is to allow testing the request/reply workflows and wire serialization
+// in the `eth` protocol without actually doing any data processing.
+type testBackend struct {
+	db    ethdb.Database
+	chain *BlockChain
+}
+
+// newTestBackend creates an empty chain and wraps it into a mock backend.
+func newTestBackend(blocks int, light bool) *testBackend {
+	return newTestBackendWithGenerator(blocks, light)
+}
+
+// newTestBackend creates a chain with a number of explicitly defined blocks and
+// wraps it into a mock backend.
+func newTestBackendWithGenerator(blocks int, lightProcess bool) *testBackend {
+	signer := types.HomesteadSigner{}
+	// Create a database pre-initialize with a genesis block
+	db := rawdb.NewMemoryDatabase()
+	db.SetDiffStore(memorydb.New())
+	(&Genesis{
+		Config: params.TestChainConfig,
+		Alloc:  GenesisAlloc{testAddr: {Balance: big.NewInt(100000000000000000)}},
+	}).MustCommit(db)
+
+	chain, _ := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil, EnablePersistDiff(860000))
+	generator := func(i int, block *BlockGen) {
+		// The chain maker doesn't have access to a chain, so the difficulty will be
+		// lets unset (nil). Set it here to the correct value.
+		block.SetCoinbase(testAddr)
+
+		// We want to simulate an empty middle block, having the same state as the
+		// first one. The last is needs a state change again to force a reorg.
+		tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), common.Address{0x01}, big.NewInt(1), params.TxGas, big.NewInt(1), nil), signer, testKey)
+		if err != nil {
+			panic(err)
+		}
+		block.AddTxWithChain(chain, tx)
+	}
+	bs, _ := GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, generator)
+	if _, err := chain.InsertChain(bs); err != nil {
+		panic(err)
+	}
+	if lightProcess {
+		EnableLightProcessor(chain)
+	}
+
+	return &testBackend{
+		db:    db,
+		chain: chain,
+	}
+}
+
+// close tears down the transaction pool and chain behind the mock backend.
+func (b *testBackend) close() {
+	b.chain.Stop()
+}
+
+func (b *testBackend) Chain() *BlockChain { return b.chain }
+
+func rawDataToDiffLayer(data rlp.RawValue) (*types.DiffLayer, error) {
+	var diff types.DiffLayer
+	hasher := sha3.NewLegacyKeccak256()
+	err := rlp.DecodeBytes(data, &diff)
+	if err != nil {
+		return nil, err
+	}
+	hasher.Write(data)
+	var diffHash common.Hash
+	hasher.Sum(diffHash[:0])
+	diff.DiffHash = diffHash
+	hasher.Reset()
+	return &diff, nil
+}
+
+func TestProcessDiffLayer(t *testing.T) {
+	t.Parallel()
+
+	blockNum := 128
+	fullBackend := newTestBackend(blockNum, false)
+	falseDiff := 5
+	defer fullBackend.close()
+
+	lightBackend := newTestBackend(0, true)
+	defer lightBackend.close()
+	for i := 1; i <= blockNum-falseDiff; i++ {
+		block := fullBackend.chain.GetBlockByNumber(uint64(i))
+		if block == nil {
+			t.Fatal("block should not be nil")
+		}
+		blockHash := block.Hash()
+		rawDiff := fullBackend.chain.GetDiffLayerRLP(blockHash)
+		diff, err := rawDataToDiffLayer(rawDiff)
+		if err != nil {
+			t.Errorf("failed to decode rawdata %v", err)
+		}
+		lightBackend.Chain().HandleDiffLayer(diff, "testpid", true)
+		_, err = lightBackend.chain.insertChain([]*types.Block{block}, true)
+		if err != nil {
+			t.Errorf("failed to insert block %v", err)
+		}
+	}
+	currentBlock := lightBackend.chain.CurrentBlock()
+	nextBlock := fullBackend.chain.GetBlockByNumber(currentBlock.NumberU64() + 1)
+	rawDiff := fullBackend.chain.GetDiffLayerRLP(nextBlock.Hash())
+	diff, _ := rawDataToDiffLayer(rawDiff)
+	latestAccount, _ := snapshot.FullAccount(diff.Accounts[0].Blob)
+	latestAccount.Balance = big.NewInt(0)
+	bz, _ := rlp.EncodeToBytes(&latestAccount)
+	diff.Accounts[0].Blob = bz
+
+	lightBackend.Chain().HandleDiffLayer(diff, "testpid", true)
+
+	_, err := lightBackend.chain.insertChain([]*types.Block{nextBlock}, true)
+	if err != nil {
+		t.Errorf("failed to process block %v", err)
+	}
+
+	// the diff cache should be cleared
+	if len(lightBackend.chain.diffPeersToDiffHashes) != 0 {
+		t.Errorf("the size of diffPeersToDiffHashes should be 0, but get %d", len(lightBackend.chain.diffPeersToDiffHashes))
+	}
+	if len(lightBackend.chain.diffHashToPeers) != 0 {
+		t.Errorf("the size of diffHashToPeers should be 0, but get %d", len(lightBackend.chain.diffHashToPeers))
+	}
+	if len(lightBackend.chain.diffHashToBlockHash) != 0 {
+		t.Errorf("the size of diffHashToBlockHash should be 0, but get %d", len(lightBackend.chain.diffHashToBlockHash))
+	}
+	if len(lightBackend.chain.blockHashToDiffLayers) != 0 {
+		t.Errorf("the size of blockHashToDiffLayers should be 0, but get %d", len(lightBackend.chain.blockHashToDiffLayers))
+	}
+}
+
+func TestFreezeDiffLayer(t *testing.T) {
+	t.Parallel()
+
+	blockNum := 1024
+	fullBackend := newTestBackend(blockNum, true)
+	defer fullBackend.close()
+	if fullBackend.chain.diffQueue.Size() != blockNum {
+		t.Errorf("size of diff queue is wrong, expected: %d, get: %d", blockNum, fullBackend.chain.diffQueue.Size())
+	}
+	time.Sleep(diffLayerFreezerRecheckInterval + 1*time.Second)
+	if fullBackend.chain.diffQueue.Size() != int(fullBackend.chain.triesInMemory) {
+		t.Errorf("size of diff queue is wrong, expected: %d, get: %d", blockNum, fullBackend.chain.diffQueue.Size())
+	}
+
+	block := fullBackend.chain.GetBlockByNumber(uint64(blockNum / 2))
+	diffStore := fullBackend.chain.db.DiffStore()
+	rawData := rawdb.ReadDiffLayerRLP(diffStore, block.Hash())
+	if len(rawData) == 0 {
+		t.Error("do not find diff layer in db")
+	}
+}
+
+func TestPruneDiffLayer(t *testing.T) {
+	t.Parallel()
+
+	blockNum := 1024
+	fullBackend := newTestBackend(blockNum, true)
+	defer fullBackend.close()
+
+	anotherFullBackend := newTestBackend(2*blockNum, true)
+	defer anotherFullBackend.close()
+
+	for num := uint64(1); num < uint64(blockNum); num++ {
+		header := fullBackend.chain.GetHeaderByNumber(num)
+		rawDiff := fullBackend.chain.GetDiffLayerRLP(header.Hash())
+		diff, _ := rawDataToDiffLayer(rawDiff)
+		fullBackend.Chain().HandleDiffLayer(diff, "testpid1", true)
+		fullBackend.Chain().HandleDiffLayer(diff, "testpid2", true)
+
+	}
+	fullBackend.chain.pruneDiffLayer()
+	if len(fullBackend.chain.diffNumToBlockHashes) != maxDiffForkDist {
+		t.Error("unexpected size of diffNumToBlockHashes")
+	}
+	if len(fullBackend.chain.diffPeersToDiffHashes) != 2 {
+		t.Error("unexpected size of diffPeersToDiffHashes")
+	}
+	if len(fullBackend.chain.blockHashToDiffLayers) != maxDiffForkDist {
+		t.Error("unexpected size of diffNumToBlockHashes")
+	}
+	if len(fullBackend.chain.diffHashToBlockHash) != maxDiffForkDist {
+		t.Error("unexpected size of diffHashToBlockHash")
+	}
+	if len(fullBackend.chain.diffHashToPeers) != maxDiffForkDist {
+		t.Error("unexpected size of diffHashToPeers")
+	}
+
+	blocks := make([]*types.Block, 0, blockNum)
+	for i := blockNum + 1; i <= 2*blockNum; i++ {
+		b := anotherFullBackend.chain.GetBlockByNumber(uint64(i))
+		blocks = append(blocks, b)
+	}
+	fullBackend.chain.insertChain(blocks, true)
+	fullBackend.chain.pruneDiffLayer()
+	if len(fullBackend.chain.diffNumToBlockHashes) != 0 {
+		t.Error("unexpected size of diffNumToBlockHashes")
+	}
+	if len(fullBackend.chain.diffPeersToDiffHashes) != 0 {
+		t.Error("unexpected size of diffPeersToDiffHashes")
+	}
+	if len(fullBackend.chain.blockHashToDiffLayers) != 0 {
+		t.Error("unexpected size of diffNumToBlockHashes")
+	}
+	if len(fullBackend.chain.diffHashToBlockHash) != 0 {
+		t.Error("unexpected size of diffHashToBlockHash")
+	}
+	if len(fullBackend.chain.diffHashToPeers) != 0 {
+		t.Error("unexpected size of diffHashToPeers")
+	}
+
+}

+ 3 - 3
core/blockchain_test.go

@@ -151,7 +151,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
 		if err != nil {
 			return err
 		}
-		receipts, _, usedGas, err := blockchain.processor.Process(block, statedb, vm.Config{})
+		statedb, receipts, _, usedGas, err := blockchain.processor.Process(block, statedb, vm.Config{})
 		if err != nil {
 			blockchain.reportBlock(block, receipts, err)
 			return err
@@ -1769,7 +1769,7 @@ func testSideImport(t *testing.T, numCanonBlocksInSidechain, blocksBetweenCommon
 	}
 
 	lastPrunedIndex := len(blocks) - TestTriesInMemory - 1
-	lastPrunedBlock := blocks[lastPrunedIndex]
+	lastPrunedBlock := blocks[lastPrunedIndex-1]
 	firstNonPrunedBlock := blocks[len(blocks)-TestTriesInMemory]
 
 	// Verify pruning of lastPrunedBlock
@@ -2420,7 +2420,7 @@ func TestSideImportPrunedBlocks(t *testing.T) {
 	}
 
 	lastPrunedIndex := len(blocks) - TestTriesInMemory - 1
-	lastPrunedBlock := blocks[lastPrunedIndex]
+	lastPrunedBlock := blocks[lastPrunedIndex-1]
 
 	// Verify pruning of lastPrunedBlock
 	if chain.HasBlockAndState(lastPrunedBlock.Hash(), lastPrunedBlock.NumberU64()) {

+ 1 - 1
core/chain_makers.go

@@ -223,7 +223,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
 			block, _, _ := b.engine.FinalizeAndAssemble(chainreader, b.header, statedb, b.txs, b.uncles, b.receipts)
 
 			// Write state changes to db
-			root, err := statedb.Commit(config.IsEIP158(b.header.Number))
+			root, _, err := statedb.Commit(config.IsEIP158(b.header.Number))
 			if err != nil {
 				panic(fmt.Sprintf("state write error: %v", err))
 			}

+ 38 - 0
core/rawdb/accessors_chain.go

@@ -447,6 +447,44 @@ func WriteBody(db ethdb.KeyValueWriter, hash common.Hash, number uint64, body *t
 	WriteBodyRLP(db, hash, number, data)
 }
 
+func WriteDiffLayer(db ethdb.KeyValueWriter, hash common.Hash, layer *types.DiffLayer) {
+	data, err := rlp.EncodeToBytes(layer)
+	if err != nil {
+		log.Crit("Failed to RLP encode diff layer", "err", err)
+	}
+	WriteDiffLayerRLP(db, hash, data)
+}
+
+func WriteDiffLayerRLP(db ethdb.KeyValueWriter, blockHash common.Hash, rlp rlp.RawValue) {
+	if err := db.Put(diffLayerKey(blockHash), rlp); err != nil {
+		log.Crit("Failed to store diff layer", "err", err)
+	}
+}
+
+func ReadDiffLayer(db ethdb.KeyValueReader, blockHash common.Hash) *types.DiffLayer {
+	data := ReadDiffLayerRLP(db, blockHash)
+	if len(data) == 0 {
+		return nil
+	}
+	diff := new(types.DiffLayer)
+	if err := rlp.Decode(bytes.NewReader(data), diff); err != nil {
+		log.Error("Invalid diff layer RLP", "hash", blockHash, "err", err)
+		return nil
+	}
+	return diff
+}
+
+func ReadDiffLayerRLP(db ethdb.KeyValueReader, blockHash common.Hash) rlp.RawValue {
+	data, _ := db.Get(diffLayerKey(blockHash))
+	return data
+}
+
+func DeleteDiffLayer(db ethdb.KeyValueWriter, blockHash common.Hash) {
+	if err := db.Delete(diffLayerKey(blockHash)); err != nil {
+		log.Crit("Failed to delete diffLayer", "err", err)
+	}
+}
+
 // DeleteBody removes all block body data associated with a hash.
 func DeleteBody(db ethdb.KeyValueWriter, hash common.Hash, number uint64) {
 	if err := db.Delete(blockBodyKey(number, hash)); err != nil {

+ 26 - 0
core/rawdb/database.go

@@ -36,6 +36,7 @@ import (
 type freezerdb struct {
 	ethdb.KeyValueStore
 	ethdb.AncientStore
+	diffStore ethdb.KeyValueStore
 }
 
 // Close implements io.Closer, closing both the fast key-value store as well as
@@ -48,12 +49,28 @@ func (frdb *freezerdb) Close() error {
 	if err := frdb.KeyValueStore.Close(); err != nil {
 		errs = append(errs, err)
 	}
+	if frdb.diffStore != nil {
+		if err := frdb.diffStore.Close(); err != nil {
+			errs = append(errs, err)
+		}
+	}
 	if len(errs) != 0 {
 		return fmt.Errorf("%v", errs)
 	}
 	return nil
 }
 
+func (frdb *freezerdb) DiffStore() ethdb.KeyValueStore {
+	return frdb.diffStore
+}
+
+func (frdb *freezerdb) SetDiffStore(diff ethdb.KeyValueStore) {
+	if frdb.diffStore != nil {
+		frdb.diffStore.Close()
+	}
+	frdb.diffStore = diff
+}
+
 // Freeze is a helper method used for external testing to trigger and block until
 // a freeze cycle completes, without having to sleep for a minute to trigger the
 // automatic background run.
@@ -77,6 +94,7 @@ func (frdb *freezerdb) Freeze(threshold uint64) error {
 // nofreezedb is a database wrapper that disables freezer data retrievals.
 type nofreezedb struct {
 	ethdb.KeyValueStore
+	diffStore ethdb.KeyValueStore
 }
 
 // HasAncient returns an error as we don't have a backing chain freezer.
@@ -114,6 +132,14 @@ func (db *nofreezedb) Sync() error {
 	return errNotSupported
 }
 
+func (db *nofreezedb) DiffStore() ethdb.KeyValueStore {
+	return db.diffStore
+}
+
+func (db *nofreezedb) SetDiffStore(diff ethdb.KeyValueStore) {
+	db.diffStore = diff
+}
+
 // NewDatabase creates a high level database on top of a given key-value data
 // store without a freezer moving immutable chain segments into cold storage.
 func NewDatabase(db ethdb.KeyValueStore) ethdb.Database {

+ 3 - 57
core/rawdb/freezer_table_test.go

@@ -18,13 +18,10 @@ package rawdb
 
 import (
 	"bytes"
-	"encoding/binary"
 	"fmt"
-	"io/ioutil"
 	"math/rand"
 	"os"
 	"path/filepath"
-	"sync"
 	"testing"
 	"time"
 
@@ -528,7 +525,6 @@ func TestOffset(t *testing.T) {
 
 		f.Append(4, getChunk(20, 0xbb))
 		f.Append(5, getChunk(20, 0xaa))
-		f.DumpIndex(0, 100)
 		f.Close()
 	}
 	// Now crop it.
@@ -575,7 +571,6 @@ func TestOffset(t *testing.T) {
 		if err != nil {
 			t.Fatal(err)
 		}
-		f.DumpIndex(0, 100)
 		// It should allow writing item 6
 		f.Append(numDeleted+2, getChunk(20, 0x99))
 
@@ -640,55 +635,6 @@ func TestOffset(t *testing.T) {
 // 1. have data files d0, d1, d2, d3
 // 2. remove d2,d3
 //
-// However, all 'normal' failure modes arising due to failing to sync() or save a file
-// should be handled already, and the case described above can only (?) happen if an
-// external process/user deletes files from the filesystem.
-
-// TestAppendTruncateParallel is a test to check if the Append/truncate operations are
-// racy.
-//
-// The reason why it's not a regular fuzzer, within tests/fuzzers, is that it is dependent
-// on timing rather than 'clever' input -- there's no determinism.
-func TestAppendTruncateParallel(t *testing.T) {
-	dir, err := ioutil.TempDir("", "freezer")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer os.RemoveAll(dir)
-
-	f, err := newCustomTable(dir, "tmp", metrics.NilMeter{}, metrics.NilMeter{}, metrics.NilGauge{}, 8, true)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	fill := func(mark uint64) []byte {
-		data := make([]byte, 8)
-		binary.LittleEndian.PutUint64(data, mark)
-		return data
-	}
-
-	for i := 0; i < 5000; i++ {
-		f.truncate(0)
-		data0 := fill(0)
-		f.Append(0, data0)
-		data1 := fill(1)
-
-		var wg sync.WaitGroup
-		wg.Add(2)
-		go func() {
-			f.truncate(0)
-			wg.Done()
-		}()
-		go func() {
-			f.Append(1, data1)
-			wg.Done()
-		}()
-		wg.Wait()
-
-		if have, err := f.Retrieve(0); err == nil {
-			if !bytes.Equal(have, data0) {
-				t.Fatalf("have %x want %x", have, data0)
-			}
-		}
-	}
-}
+// However, all 'normal' failure modes arising due to failing to sync() or save a file should be
+// handled already, and the case described above can only (?) happen if an external process/user
+// deletes files from the filesystem.

+ 8 - 0
core/rawdb/schema.go

@@ -90,6 +90,9 @@ var (
 	SnapshotStoragePrefix = []byte("o") // SnapshotStoragePrefix + account hash + storage hash -> storage trie value
 	CodePrefix            = []byte("c") // CodePrefix + code hash -> account code
 
+	// difflayer database
+	diffLayerPrefix = []byte("d") // diffLayerPrefix + hash  -> diffLayer
+
 	preimagePrefix = []byte("secure-key-")      // preimagePrefix + hash -> preimage
 	configPrefix   = []byte("ethereum-config-") // config prefix for the db
 
@@ -177,6 +180,11 @@ func blockReceiptsKey(number uint64, hash common.Hash) []byte {
 	return append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...)
 }
 
+// diffLayerKey = diffLayerKeyPrefix + hash
+func diffLayerKey(hash common.Hash) []byte {
+	return append(append(diffLayerPrefix, hash.Bytes()...))
+}
+
 // txLookupKey = txLookupPrefix + hash
 func txLookupKey(hash common.Hash) []byte {
 	return append(txLookupPrefix, hash.Bytes()...)

+ 8 - 0
core/rawdb/table.go

@@ -159,6 +159,14 @@ func (t *table) NewBatch() ethdb.Batch {
 	return &tableBatch{t.db.NewBatch(), t.prefix}
 }
 
+func (t *table) DiffStore() ethdb.KeyValueStore {
+	return nil
+}
+
+func (t *table) SetDiffStore(diff ethdb.KeyValueStore) {
+	panic("not implement")
+}
+
 // tableBatch is a wrapper around a database batch that prefixes each key access
 // with a pre-configured string.
 type tableBatch struct {

+ 0 - 1
core/state/database.go

@@ -243,7 +243,6 @@ func (db *cachingDB) CacheStorage(addrHash common.Hash, root common.Hash, t Trie
 		triesArray := [3]*triePair{{root: root, trie: tr.ResetCopy()}, nil, nil}
 		db.storageTrieCache.Add(addrHash, triesArray)
 	}
-	return
 }
 
 func (db *cachingDB) Purge() {

+ 1 - 1
core/state/journal.go

@@ -153,7 +153,7 @@ func (ch createObjectChange) dirtied() *common.Address {
 func (ch resetObjectChange) revert(s *StateDB) {
 	s.SetStateObject(ch.prev)
 	if !ch.prevdestruct && s.snap != nil {
-		delete(s.snapDestructs, ch.prev.addrHash)
+		delete(s.snapDestructs, ch.prev.address)
 	}
 }
 

+ 4 - 4
core/state/snapshot/disklayer_test.go

@@ -121,7 +121,7 @@ func TestDiskMerge(t *testing.T) {
 	base.Storage(conNukeCache, conNukeCacheSlot)
 
 	// Modify or delete some accounts, flatten everything onto disk
-	if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{
+	if err := snaps.update(diffRoot, baseRoot, map[common.Hash]struct{}{
 		accDelNoCache:  {},
 		accDelCache:    {},
 		conNukeNoCache: {},
@@ -344,7 +344,7 @@ func TestDiskPartialMerge(t *testing.T) {
 		assertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
 
 		// Modify or delete some accounts, flatten everything onto disk
-		if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{
+		if err := snaps.update(diffRoot, baseRoot, map[common.Hash]struct{}{
 			accDelNoCache:  {},
 			accDelCache:    {},
 			conNukeNoCache: {},
@@ -466,7 +466,7 @@ func TestDiskGeneratorPersistence(t *testing.T) {
 		},
 	}
 	// Modify or delete some accounts, flatten everything onto disk
-	if err := snaps.Update(diffRoot, baseRoot, nil, map[common.Hash][]byte{
+	if err := snaps.update(diffRoot, baseRoot, nil, map[common.Hash][]byte{
 		accTwo: accTwo[:],
 	}, nil); err != nil {
 		t.Fatalf("failed to update snapshot tree: %v", err)
@@ -484,7 +484,7 @@ func TestDiskGeneratorPersistence(t *testing.T) {
 	}
 	// Test scenario 2, the disk layer is fully generated
 	// Modify or delete some accounts, flatten everything onto disk
-	if err := snaps.Update(diffTwoRoot, diffRoot, nil, map[common.Hash][]byte{
+	if err := snaps.update(diffTwoRoot, diffRoot, nil, map[common.Hash][]byte{
 		accThree: accThree.Bytes(),
 	}, map[common.Hash]map[common.Hash][]byte{
 		accThree: {accThreeSlot: accThreeSlot.Bytes()},

+ 43 - 43
core/state/snapshot/iterator_test.go

@@ -221,13 +221,13 @@ func TestAccountIteratorTraversal(t *testing.T) {
 		},
 	}
 	// Stack three diff layers on top with various overlaps
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
 		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
 
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
 		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
 
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
 		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
 
 	// Verify the single and multi-layer iterators
@@ -268,13 +268,13 @@ func TestStorageIteratorTraversal(t *testing.T) {
 		},
 	}
 	// Stack three diff layers on top with various overlaps
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x02", "0x03"}}, nil))
 
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x04", "0x05", "0x06"}}, nil))
 
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x02", "0x03"}}, nil))
 
 	// Verify the single and multi-layer iterators
@@ -353,14 +353,14 @@ func TestAccountIteratorTraversalValues(t *testing.T) {
 		}
 	}
 	// Assemble a stack of snapshots from the account layers
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, a, nil)
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, b, nil)
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, c, nil)
-	snaps.Update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil, d, nil)
-	snaps.Update(common.HexToHash("0x06"), common.HexToHash("0x05"), nil, e, nil)
-	snaps.Update(common.HexToHash("0x07"), common.HexToHash("0x06"), nil, f, nil)
-	snaps.Update(common.HexToHash("0x08"), common.HexToHash("0x07"), nil, g, nil)
-	snaps.Update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, h, nil)
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, a, nil)
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, b, nil)
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, c, nil)
+	snaps.update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil, d, nil)
+	snaps.update(common.HexToHash("0x06"), common.HexToHash("0x05"), nil, e, nil)
+	snaps.update(common.HexToHash("0x07"), common.HexToHash("0x06"), nil, f, nil)
+	snaps.update(common.HexToHash("0x08"), common.HexToHash("0x07"), nil, g, nil)
+	snaps.update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, h, nil)
 
 	it, _ := snaps.AccountIterator(common.HexToHash("0x09"), common.Hash{})
 	head := snaps.Snapshot(common.HexToHash("0x09"))
@@ -452,14 +452,14 @@ func TestStorageIteratorTraversalValues(t *testing.T) {
 		}
 	}
 	// Assemble a stack of snapshots from the account layers
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, randomAccountSet("0xaa"), wrapStorage(a))
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, randomAccountSet("0xaa"), wrapStorage(b))
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, randomAccountSet("0xaa"), wrapStorage(c))
-	snaps.Update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil, randomAccountSet("0xaa"), wrapStorage(d))
-	snaps.Update(common.HexToHash("0x06"), common.HexToHash("0x05"), nil, randomAccountSet("0xaa"), wrapStorage(e))
-	snaps.Update(common.HexToHash("0x07"), common.HexToHash("0x06"), nil, randomAccountSet("0xaa"), wrapStorage(e))
-	snaps.Update(common.HexToHash("0x08"), common.HexToHash("0x07"), nil, randomAccountSet("0xaa"), wrapStorage(g))
-	snaps.Update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, randomAccountSet("0xaa"), wrapStorage(h))
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, randomAccountSet("0xaa"), wrapStorage(a))
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, randomAccountSet("0xaa"), wrapStorage(b))
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, randomAccountSet("0xaa"), wrapStorage(c))
+	snaps.update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil, randomAccountSet("0xaa"), wrapStorage(d))
+	snaps.update(common.HexToHash("0x06"), common.HexToHash("0x05"), nil, randomAccountSet("0xaa"), wrapStorage(e))
+	snaps.update(common.HexToHash("0x07"), common.HexToHash("0x06"), nil, randomAccountSet("0xaa"), wrapStorage(e))
+	snaps.update(common.HexToHash("0x08"), common.HexToHash("0x07"), nil, randomAccountSet("0xaa"), wrapStorage(g))
+	snaps.update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, randomAccountSet("0xaa"), wrapStorage(h))
 
 	it, _ := snaps.StorageIterator(common.HexToHash("0x09"), common.HexToHash("0xaa"), common.Hash{})
 	head := snaps.Snapshot(common.HexToHash("0x09"))
@@ -522,7 +522,7 @@ func TestAccountIteratorLargeTraversal(t *testing.T) {
 		},
 	}
 	for i := 1; i < 128; i++ {
-		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
+		snaps.update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
 	}
 	// Iterate the entire stack and ensure everything is hit only once
 	head := snaps.Snapshot(common.HexToHash("0x80"))
@@ -566,13 +566,13 @@ func TestAccountIteratorFlattening(t *testing.T) {
 		},
 	}
 	// Create a stack of diffs on top
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
 		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
 
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
 		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
 
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
 		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
 
 	// Create an iterator and flatten the data from underneath it
@@ -597,13 +597,13 @@ func TestAccountIteratorSeek(t *testing.T) {
 			base.root: base,
 		},
 	}
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
 		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
 
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
 		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
 
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
 		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
 
 	// Account set is now
@@ -661,13 +661,13 @@ func TestStorageIteratorSeek(t *testing.T) {
 		},
 	}
 	// Stack three diff layers on top with various overlaps
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x03", "0x05"}}, nil))
 
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x02", "0x05", "0x06"}}, nil))
 
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x05", "0x08"}}, nil))
 
 	// Account set is now
@@ -724,17 +724,17 @@ func TestAccountIteratorDeletions(t *testing.T) {
 		},
 	}
 	// Stack three diff layers on top with various overlaps
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"),
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"),
 		nil, randomAccountSet("0x11", "0x22", "0x33"), nil)
 
 	deleted := common.HexToHash("0x22")
 	destructed := map[common.Hash]struct{}{
 		deleted: {},
 	}
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"),
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"),
 		destructed, randomAccountSet("0x11", "0x33"), nil)
 
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"),
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"),
 		nil, randomAccountSet("0x33", "0x44", "0x55"), nil)
 
 	// The output should be 11,33,44,55
@@ -770,10 +770,10 @@ func TestStorageIteratorDeletions(t *testing.T) {
 		},
 	}
 	// Stack three diff layers on top with various overlaps
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x03", "0x05"}}, nil))
 
-	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
+	snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x02", "0x04", "0x06"}}, [][]string{{"0x01", "0x03"}}))
 
 	// The output should be 02,04,05,06
@@ -790,14 +790,14 @@ func TestStorageIteratorDeletions(t *testing.T) {
 	destructed := map[common.Hash]struct{}{
 		common.HexToHash("0xaa"): {},
 	}
-	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), destructed, nil, nil)
+	snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), destructed, nil, nil)
 
 	it, _ = snaps.StorageIterator(common.HexToHash("0x04"), common.HexToHash("0xaa"), common.Hash{})
 	verifyIterator(t, 0, it, verifyStorage)
 	it.Release()
 
 	// Re-insert the slots of the same account
-	snaps.Update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil,
+	snaps.update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil,
 		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x07", "0x08", "0x09"}}, nil))
 
 	// The output should be 07,08,09
@@ -806,7 +806,7 @@ func TestStorageIteratorDeletions(t *testing.T) {
 	it.Release()
 
 	// Destruct the whole storage but re-create the account in the same layer
-	snaps.Update(common.HexToHash("0x06"), common.HexToHash("0x05"), destructed, randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x11", "0x12"}}, nil))
+	snaps.update(common.HexToHash("0x06"), common.HexToHash("0x05"), destructed, randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x11", "0x12"}}, nil))
 	it, _ = snaps.StorageIterator(common.HexToHash("0x06"), common.HexToHash("0xaa"), common.Hash{})
 	verifyIterator(t, 2, it, verifyStorage) // The output should be 11,12
 	it.Release()
@@ -848,7 +848,7 @@ func BenchmarkAccountIteratorTraversal(b *testing.B) {
 		},
 	}
 	for i := 1; i <= 100; i++ {
-		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
+		snaps.update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil)
 	}
 	// We call this once before the benchmark, so the creation of
 	// sorted accountlists are not included in the results.
@@ -943,9 +943,9 @@ func BenchmarkAccountIteratorLargeBaselayer(b *testing.B) {
 			base.root: base,
 		},
 	}
-	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, makeAccounts(2000), nil)
+	snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, makeAccounts(2000), nil)
 	for i := 2; i <= 100; i++ {
-		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(20), nil)
+		snaps.update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(20), nil)
 	}
 	// We call this once before the benchmark, so the creation of
 	// sorted accountlists are not included in the results.

+ 31 - 2
core/state/snapshot/snapshot.go

@@ -26,6 +26,7 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/metrics"
@@ -59,7 +60,6 @@ var (
 	snapshotDirtyStorageWriteMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/write", nil)
 
 	snapshotDirtyAccountHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/account/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015))
-	snapshotDirtyStorageHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/storage/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015))
 
 	snapshotFlushAccountItemMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/item", nil)
 	snapshotFlushAccountSizeMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/size", nil)
@@ -322,9 +322,14 @@ func (t *Tree) Snapshots(root common.Hash, limits int, nodisk bool) []Snapshot {
 	return ret
 }
 
+func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs map[common.Address]struct{}, accounts map[common.Address][]byte, storage map[common.Address]map[string][]byte) error {
+	hashDestructs, hashAccounts, hashStorage := transformSnapData(destructs, accounts, storage)
+	return t.update(blockRoot, parentRoot, hashDestructs, hashAccounts, hashStorage)
+}
+
 // Update adds a new snapshot into the tree, if that can be linked to an existing
 // old parent. It is disallowed to insert a disk layer (the origin of all).
-func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) error {
+func (t *Tree) update(blockRoot common.Hash, parentRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) error {
 	// Reject noop updates to avoid self-loops in the snapshot tree. This is a
 	// special case that can only happen for Clique networks where empty blocks
 	// don't modify the state (0 block subsidy).
@@ -836,3 +841,27 @@ func (t *Tree) DiskRoot() common.Hash {
 
 	return t.diskRoot()
 }
+
+// TODO we can further improve it when the set is very large
+func transformSnapData(destructs map[common.Address]struct{}, accounts map[common.Address][]byte,
+	storage map[common.Address]map[string][]byte) (map[common.Hash]struct{}, map[common.Hash][]byte,
+	map[common.Hash]map[common.Hash][]byte) {
+	hasher := crypto.NewKeccakState()
+	hashDestructs := make(map[common.Hash]struct{}, len(destructs))
+	hashAccounts := make(map[common.Hash][]byte, len(accounts))
+	hashStorages := make(map[common.Hash]map[common.Hash][]byte, len(storage))
+	for addr := range destructs {
+		hashDestructs[crypto.Keccak256Hash(addr[:])] = struct{}{}
+	}
+	for addr, account := range accounts {
+		hashAccounts[crypto.Keccak256Hash(addr[:])] = account
+	}
+	for addr, accountStore := range storage {
+		hashStorage := make(map[common.Hash][]byte, len(accountStore))
+		for k, v := range accountStore {
+			hashStorage[crypto.HashData(hasher, []byte(k))] = v
+		}
+		hashStorages[crypto.Keccak256Hash(addr[:])] = hashStorage
+	}
+	return hashDestructs, hashAccounts, hashStorages
+}

+ 12 - 12
core/state/snapshot/snapshot_test.go

@@ -105,7 +105,7 @@ func TestDiskLayerExternalInvalidationFullFlatten(t *testing.T) {
 	accounts := map[common.Hash][]byte{
 		common.HexToHash("0xa1"): randomAccount(),
 	}
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+	if err := snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
 	if n := len(snaps.layers); n != 2 {
@@ -149,10 +149,10 @@ func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) {
 	accounts := map[common.Hash][]byte{
 		common.HexToHash("0xa1"): randomAccount(),
 	}
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+	if err := snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
-	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+	if err := snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
 	if n := len(snaps.layers); n != 3 {
@@ -197,13 +197,13 @@ func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) {
 	accounts := map[common.Hash][]byte{
 		common.HexToHash("0xa1"): randomAccount(),
 	}
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
+	if err := snaps.update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
-	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
+	if err := snaps.update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
-	if err := snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, accounts, nil); err != nil {
+	if err := snaps.update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
 	if n := len(snaps.layers); n != 4 {
@@ -257,12 +257,12 @@ func TestPostCapBasicDataAccess(t *testing.T) {
 		},
 	}
 	// The lowest difflayer
-	snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil)
-	snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil)
-	snaps.Update(common.HexToHash("0xb2"), common.HexToHash("0xa1"), nil, setAccount("0xb2"), nil)
+	snaps.update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil)
+	snaps.update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil)
+	snaps.update(common.HexToHash("0xb2"), common.HexToHash("0xa1"), nil, setAccount("0xb2"), nil)
 
-	snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil)
-	snaps.Update(common.HexToHash("0xb3"), common.HexToHash("0xb2"), nil, setAccount("0xb3"), nil)
+	snaps.update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil)
+	snaps.update(common.HexToHash("0xb3"), common.HexToHash("0xb2"), nil, setAccount("0xb3"), nil)
 
 	// checkExist verifies if an account exiss in a snapshot
 	checkExist := func(layer *diffLayer, key string) error {
@@ -357,7 +357,7 @@ func TestSnaphots(t *testing.T) {
 	)
 	for i := 0; i < 129; i++ {
 		head = makeRoot(uint64(i + 2))
-		snaps.Update(head, last, nil, setAccount(fmt.Sprintf("%d", i+2)), nil)
+		snaps.update(head, last, nil, setAccount(fmt.Sprintf("%d", i+2)), nil)
 		last = head
 		snaps.Cap(head, 128) // 130 layers (128 diffs + 1 accumulator + 1 disk)
 	}

+ 6 - 7
core/state/state_object.go

@@ -234,7 +234,7 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) common.Has
 		//   1) resurrect happened, and new slot values were set -- those should
 		//      have been handles via pendingStorage above.
 		//   2) we don't have new values, and can deliver empty response back
-		if _, destructed := s.db.snapDestructs[s.addrHash]; destructed {
+		if _, destructed := s.db.snapDestructs[s.address]; destructed {
 			return common.Hash{}
 		}
 		enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes()))
@@ -345,10 +345,9 @@ func (s *StateObject) updateTrie(db Database) Trie {
 		}(time.Now())
 	}
 	// The snapshot storage map for the object
-	var storage map[common.Hash][]byte
+	var storage map[string][]byte
 	// Insert all the pending updates into the trie
 	tr := s.getTrie(db)
-	hasher := s.db.hasher
 
 	usedStorage := make([][]byte, 0, len(s.pendingStorage))
 	for key, value := range s.pendingStorage {
@@ -371,12 +370,12 @@ func (s *StateObject) updateTrie(db Database) Trie {
 			s.db.snapMux.Lock()
 			if storage == nil {
 				// Retrieve the old storage map, if available, create a new one otherwise
-				if storage = s.db.snapStorage[s.addrHash]; storage == nil {
-					storage = make(map[common.Hash][]byte)
-					s.db.snapStorage[s.addrHash] = storage
+				if storage = s.db.snapStorage[s.address]; storage == nil {
+					storage = make(map[string][]byte)
+					s.db.snapStorage[s.address] = storage
 				}
 			}
-			storage[crypto.HashData(hasher, key[:])] = v // v will be nil if value is 0x00
+			storage[string(key[:])] = v // v will be nil if value is 0x00
 			s.db.snapMux.Unlock()
 		}
 		usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure

+ 1 - 1
core/state/state_test.go

@@ -167,7 +167,7 @@ func TestSnapshot2(t *testing.T) {
 	so0.deleted = false
 	state.SetStateObject(so0)
 
-	root, _ := state.Commit(false)
+	root, _, _ := state.Commit(false)
 	state, _ = New(root, state.db, state.snaps)
 
 	// and one with deleted == true

+ 273 - 36
core/state/statedb.go

@@ -27,6 +27,7 @@ import (
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/gopool"
 	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/core/state/snapshot"
 	"github.com/ethereum/go-ethereum/core/types"
@@ -39,7 +40,7 @@ import (
 )
 
 const (
-	preLoadLimit      = 64
+	preLoadLimit      = 128
 	defaultNumOfSlots = 100
 )
 
@@ -72,18 +73,22 @@ func (n *proofList) Delete(key []byte) error {
 // * Contracts
 // * Accounts
 type StateDB struct {
-	db           Database
-	prefetcher   *triePrefetcher
-	originalRoot common.Hash // The pre-state root, before any changes were made
-	trie         Trie
-	hasher       crypto.KeccakState
+	db             Database
+	prefetcher     *triePrefetcher
+	originalRoot   common.Hash // The pre-state root, before any changes were made
+	trie           Trie
+	hasher         crypto.KeccakState
+	diffLayer      *types.DiffLayer
+	diffTries      map[common.Address]Trie
+	diffCode       map[common.Hash][]byte
+	lightProcessed bool
 
 	snapMux       sync.Mutex
 	snaps         *snapshot.Tree
 	snap          snapshot.Snapshot
-	snapDestructs map[common.Hash]struct{}
-	snapAccounts  map[common.Hash][]byte
-	snapStorage   map[common.Hash]map[common.Hash][]byte
+	snapDestructs map[common.Address]struct{}
+	snapAccounts  map[common.Address][]byte
+	snapStorage   map[common.Address]map[string][]byte
 
 	// This map holds 'live' objects, which will get modified while processing a state transition.
 	stateObjects        map[common.Address]*StateObject
@@ -156,9 +161,9 @@ func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB,
 	sdb.trie = tr
 	if sdb.snaps != nil {
 		if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil {
-			sdb.snapDestructs = make(map[common.Hash]struct{})
-			sdb.snapAccounts = make(map[common.Hash][]byte)
-			sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
+			sdb.snapDestructs = make(map[common.Address]struct{})
+			sdb.snapAccounts = make(map[common.Address][]byte)
+			sdb.snapStorage = make(map[common.Address]map[string][]byte)
 		}
 	}
 	return sdb, nil
@@ -186,6 +191,15 @@ func (s *StateDB) StopPrefetcher() {
 	}
 }
 
+// Mark that the block is processed by diff layer
+func (s *StateDB) MarkLightProcessed() {
+	s.lightProcessed = true
+}
+
+func (s *StateDB) IsLightProcessed() bool {
+	return s.lightProcessed
+}
+
 // setError remembers the first non-nil error it is called with.
 func (s *StateDB) setError(err error) {
 	if s.dbErr == nil {
@@ -197,6 +211,19 @@ func (s *StateDB) Error() error {
 	return s.dbErr
 }
 
+func (s *StateDB) Trie() Trie {
+	return s.trie
+}
+
+func (s *StateDB) SetDiff(diffLayer *types.DiffLayer, diffTries map[common.Address]Trie, diffCode map[common.Hash][]byte) {
+	s.diffLayer, s.diffTries, s.diffCode = diffLayer, diffTries, diffCode
+}
+
+func (s *StateDB) SetSnapData(snapDestructs map[common.Address]struct{}, snapAccounts map[common.Address][]byte,
+	snapStorage map[common.Address]map[string][]byte) {
+	s.snapDestructs, s.snapAccounts, s.snapStorage = snapDestructs, snapAccounts, snapStorage
+}
+
 func (s *StateDB) AddLog(log *types.Log) {
 	s.journal.append(addLogChange{txhash: s.thash})
 
@@ -532,7 +559,7 @@ func (s *StateDB) TryPreload(block *types.Block, signer types.Signer) {
 			accounts[*tx.To()] = true
 		}
 	}
-	for account, _ := range accounts {
+	for account := range accounts {
 		accountsSlice = append(accountsSlice, account)
 	}
 	if len(accountsSlice) >= preLoadLimit && len(accountsSlice) > runtime.NumCPU() {
@@ -550,10 +577,8 @@ func (s *StateDB) TryPreload(block *types.Block, signer types.Signer) {
 		}
 		for i := 0; i < runtime.NumCPU(); i++ {
 			objs := <-objsChan
-			if objs != nil {
-				for _, obj := range objs {
-					s.SetStateObject(obj)
-				}
+			for _, obj := range objs {
+				s.SetStateObject(obj)
 			}
 		}
 	}
@@ -683,9 +708,9 @@ func (s *StateDB) createObject(addr common.Address) (newobj, prev *StateObject)
 
 	var prevdestruct bool
 	if s.snap != nil && prev != nil {
-		_, prevdestruct = s.snapDestructs[prev.addrHash]
+		_, prevdestruct = s.snapDestructs[prev.address]
 		if !prevdestruct {
-			s.snapDestructs[prev.addrHash] = struct{}{}
+			s.snapDestructs[prev.address] = struct{}{}
 		}
 	}
 	newobj = newObject(s, addr, Account{})
@@ -830,17 +855,17 @@ func (s *StateDB) Copy() *StateDB {
 		state.snaps = s.snaps
 		state.snap = s.snap
 		// deep copy needed
-		state.snapDestructs = make(map[common.Hash]struct{})
+		state.snapDestructs = make(map[common.Address]struct{})
 		for k, v := range s.snapDestructs {
 			state.snapDestructs[k] = v
 		}
-		state.snapAccounts = make(map[common.Hash][]byte)
+		state.snapAccounts = make(map[common.Address][]byte)
 		for k, v := range s.snapAccounts {
 			state.snapAccounts[k] = v
 		}
-		state.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
+		state.snapStorage = make(map[common.Address]map[string][]byte)
 		for k, v := range s.snapStorage {
-			temp := make(map[common.Hash][]byte)
+			temp := make(map[string][]byte)
 			for kk, vv := range v {
 				temp[kk] = vv
 			}
@@ -903,9 +928,9 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) {
 			// transactions within the same block might self destruct and then
 			// ressurrect an account; but the snapshotter needs both events.
 			if s.snap != nil {
-				s.snapDestructs[obj.addrHash] = struct{}{} // We need to maintain account deletions explicitly (will remain set indefinitely)
-				delete(s.snapAccounts, obj.addrHash)       // Clear out any previously updated account data (may be recreated via a ressurrect)
-				delete(s.snapStorage, obj.addrHash)        // Clear out any previously updated storage data (may be recreated via a ressurrect)
+				s.snapDestructs[obj.address] = struct{}{} // We need to maintain account deletions explicitly (will remain set indefinitely)
+				delete(s.snapAccounts, obj.address)       // Clear out any previously updated account data (may be recreated via a ressurrect)
+				delete(s.snapStorage, obj.address)        // Clear out any previously updated storage data (may be recreated via a ressurrect)
 			}
 		} else {
 			obj.finalise(true) // Prefetch slots in the background
@@ -932,6 +957,9 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) {
 // It is called in between transactions to get the root hash that
 // goes into transaction receipts.
 func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
+	if s.lightProcessed {
+		return s.trie.Hash()
+	}
 	// Finalise all the dirty storage states and write them into the tries
 	s.Finalise(deleteEmptyObjects)
 
@@ -983,7 +1011,8 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
 				// at transaction boundary level to ensure we capture state clearing.
 				if s.snap != nil && !obj.deleted {
 					s.snapMux.Lock()
-					s.snapAccounts[obj.addrHash] = snapshot.SlimAccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash)
+					// It is possible to add unnecessary change, but it is fine.
+					s.snapAccounts[obj.address] = snapshot.SlimAccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash)
 					s.snapMux.Unlock()
 				}
 				data, err := rlp.EncodeToBytes(obj)
@@ -1007,7 +1036,7 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
 	if s.trie == nil {
 		tr, err := s.db.OpenTrie(s.originalRoot)
 		if err != nil {
-			panic(fmt.Sprintf("Failed to open trie tree"))
+			panic("Failed to open trie tree")
 		}
 		s.trie = tr
 	}
@@ -1051,14 +1080,143 @@ func (s *StateDB) clearJournalAndRefund() {
 	s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entires
 }
 
+func (s *StateDB) LightCommit(root common.Hash) (common.Hash, *types.DiffLayer, error) {
+	codeWriter := s.db.TrieDB().DiskDB().NewBatch()
+
+	commitFuncs := []func() error{
+		func() error {
+			for codeHash, code := range s.diffCode {
+				rawdb.WriteCode(codeWriter, codeHash, code)
+				if codeWriter.ValueSize() >= ethdb.IdealBatchSize {
+					if err := codeWriter.Write(); err != nil {
+						return err
+					}
+					codeWriter.Reset()
+				}
+			}
+			if codeWriter.ValueSize() > 0 {
+				if err := codeWriter.Write(); err != nil {
+					return err
+				}
+			}
+			return nil
+		},
+		func() error {
+			tasks := make(chan func())
+			taskResults := make(chan error, len(s.diffTries))
+			tasksNum := 0
+			finishCh := make(chan struct{})
+			defer close(finishCh)
+			threads := gopool.Threads(len(s.diffTries))
+
+			for i := 0; i < threads; i++ {
+				go func() {
+					for {
+						select {
+						case task := <-tasks:
+							task()
+						case <-finishCh:
+							return
+						}
+					}
+				}()
+			}
+
+			for account, diff := range s.diffTries {
+				tmpAccount := account
+				tmpDiff := diff
+				tasks <- func() {
+					root, err := tmpDiff.Commit(nil)
+					if err != nil {
+						taskResults <- err
+						return
+					}
+					s.db.CacheStorage(crypto.Keccak256Hash(tmpAccount[:]), root, tmpDiff)
+					taskResults <- nil
+				}
+				tasksNum++
+			}
+
+			for i := 0; i < tasksNum; i++ {
+				err := <-taskResults
+				if err != nil {
+					return err
+				}
+			}
+
+			// commit account trie
+			var account Account
+			root, err := s.trie.Commit(func(_ [][]byte, _ []byte, leaf []byte, parent common.Hash) error {
+				if err := rlp.DecodeBytes(leaf, &account); err != nil {
+					return nil
+				}
+				if account.Root != emptyRoot {
+					s.db.TrieDB().Reference(account.Root, parent)
+				}
+				return nil
+			})
+			if err != nil {
+				return err
+			}
+			if root != emptyRoot {
+				s.db.CacheAccount(root, s.trie)
+			}
+			return nil
+		},
+		func() error {
+			if s.snap != nil {
+				if metrics.EnabledExpensive {
+					defer func(start time.Time) { s.SnapshotCommits += time.Since(start) }(time.Now())
+				}
+				// Only update if there's a state transition (skip empty Clique blocks)
+				if parent := s.snap.Root(); parent != root {
+					if err := s.snaps.Update(root, parent, s.snapDestructs, s.snapAccounts, s.snapStorage); err != nil {
+						log.Warn("Failed to update snapshot tree", "from", parent, "to", root, "err", err)
+					}
+					// Keep n diff layers in the memory
+					// - head layer is paired with HEAD state
+					// - head-1 layer is paired with HEAD-1 state
+					// - head-(n-1) layer(bottom-most diff layer) is paired with HEAD-(n-1)state
+					if err := s.snaps.Cap(root, s.snaps.CapLimit()); err != nil {
+						log.Warn("Failed to cap snapshot tree", "root", root, "layers", s.snaps.CapLimit(), "err", err)
+					}
+				}
+			}
+			return nil
+		},
+	}
+	commitRes := make(chan error, len(commitFuncs))
+	for _, f := range commitFuncs {
+		tmpFunc := f
+		go func() {
+			commitRes <- tmpFunc()
+		}()
+	}
+	for i := 0; i < len(commitFuncs); i++ {
+		r := <-commitRes
+		if r != nil {
+			return common.Hash{}, nil, r
+		}
+	}
+	s.snap, s.snapDestructs, s.snapAccounts, s.snapStorage = nil, nil, nil, nil
+	s.diffTries, s.diffCode = nil, nil
+	return root, s.diffLayer, nil
+}
+
 // Commit writes the state to the underlying in-memory trie database.
-func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
+func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, *types.DiffLayer, error) {
 	if s.dbErr != nil {
-		return common.Hash{}, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr)
+		return common.Hash{}, nil, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr)
 	}
 	// Finalize any pending changes and merge everything into the tries
 	root := s.IntermediateRoot(deleteEmptyObjects)
-
+	if s.lightProcessed {
+		return s.LightCommit(root)
+	}
+	var diffLayer *types.DiffLayer
+	if s.snap != nil {
+		diffLayer = &types.DiffLayer{}
+	}
 	commitFuncs := []func() error{
 		func() error {
 			// Commit objects to the trie, measuring the elapsed time
@@ -1066,9 +1224,13 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
 			taskResults := make(chan error, len(s.stateObjectsDirty))
 			tasksNum := 0
 			finishCh := make(chan struct{})
-			defer close(finishCh)
-			for i := 0; i < runtime.NumCPU(); i++ {
+
+			threads := gopool.Threads(len(s.stateObjectsDirty))
+			wg := sync.WaitGroup{}
+			for i := 0; i < threads; i++ {
+				wg.Add(1)
 				go func() {
+					defer wg.Done()
 					codeWriter := s.db.TrieDB().DiskDB().NewBatch()
 					for {
 						select {
@@ -1086,6 +1248,19 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
 				}()
 			}
 
+			if s.snap != nil {
+				for addr := range s.stateObjectsDirty {
+					if obj := s.stateObjects[addr]; !obj.deleted {
+						if obj.code != nil && obj.dirtyCode {
+							diffLayer.Codes = append(diffLayer.Codes, types.DiffCode{
+								Hash: common.BytesToHash(obj.CodeHash()),
+								Code: obj.code,
+							})
+						}
+					}
+				}
+			}
+
 			for addr := range s.stateObjectsDirty {
 				if obj := s.stateObjects[addr]; !obj.deleted {
 					// Write any contract code associated with the state object
@@ -1107,9 +1282,11 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
 			for i := 0; i < tasksNum; i++ {
 				err := <-taskResults
 				if err != nil {
+					close(finishCh)
 					return err
 				}
 			}
+			close(finishCh)
 
 			if len(s.stateObjectsDirty) > 0 {
 				s.stateObjectsDirty = make(map[common.Address]struct{}, len(s.stateObjectsDirty)/2)
@@ -1140,6 +1317,7 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
 			if root != emptyRoot {
 				s.db.CacheAccount(root, s.trie)
 			}
+			wg.Wait()
 			return nil
 		},
 		func() error {
@@ -1161,7 +1339,12 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
 						log.Warn("Failed to cap snapshot tree", "root", root, "layers", s.snaps.CapLimit(), "err", err)
 					}
 				}
-				s.snap, s.snapDestructs, s.snapAccounts, s.snapStorage = nil, nil, nil, nil
+			}
+			return nil
+		},
+		func() error {
+			if s.snap != nil {
+				diffLayer.Destructs, diffLayer.Accounts, diffLayer.Storages = s.SnapToDiffLayer()
 			}
 			return nil
 		},
@@ -1176,11 +1359,65 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
 	for i := 0; i < len(commitFuncs); i++ {
 		r := <-commitRes
 		if r != nil {
-			return common.Hash{}, r
+			return common.Hash{}, nil, r
 		}
 	}
+	s.snap, s.snapDestructs, s.snapAccounts, s.snapStorage = nil, nil, nil, nil
+	return root, diffLayer, nil
+}
 
-	return root, nil
+func (s *StateDB) DiffLayerToSnap(diffLayer *types.DiffLayer) (map[common.Address]struct{}, map[common.Address][]byte, map[common.Address]map[string][]byte, error) {
+	snapDestructs := make(map[common.Address]struct{})
+	snapAccounts := make(map[common.Address][]byte)
+	snapStorage := make(map[common.Address]map[string][]byte)
+
+	for _, des := range diffLayer.Destructs {
+		snapDestructs[des] = struct{}{}
+	}
+	for _, account := range diffLayer.Accounts {
+		snapAccounts[account.Account] = account.Blob
+	}
+	for _, storage := range diffLayer.Storages {
+		// should never happen
+		if len(storage.Keys) != len(storage.Vals) {
+			return nil, nil, nil, errors.New("invalid diffLayer: length of keys and values mismatch")
+		}
+		snapStorage[storage.Account] = make(map[string][]byte, len(storage.Keys))
+		n := len(storage.Keys)
+		for i := 0; i < n; i++ {
+			snapStorage[storage.Account][storage.Keys[i]] = storage.Vals[i]
+		}
+	}
+	return snapDestructs, snapAccounts, snapStorage, nil
+}
+
+func (s *StateDB) SnapToDiffLayer() ([]common.Address, []types.DiffAccount, []types.DiffStorage) {
+	destructs := make([]common.Address, 0, len(s.snapDestructs))
+	for account := range s.snapDestructs {
+		destructs = append(destructs, account)
+	}
+	accounts := make([]types.DiffAccount, 0, len(s.snapAccounts))
+	for accountHash, account := range s.snapAccounts {
+		accounts = append(accounts, types.DiffAccount{
+			Account: accountHash,
+			Blob:    account,
+		})
+	}
+	storages := make([]types.DiffStorage, 0, len(s.snapStorage))
+	for accountHash, storage := range s.snapStorage {
+		keys := make([]string, 0, len(storage))
+		values := make([][]byte, 0, len(storage))
+		for k, v := range storage {
+			keys = append(keys, k)
+			values = append(values, v)
+		}
+		storages = append(storages, types.DiffStorage{
+			Account: accountHash,
+			Keys:    keys,
+			Vals:    values,
+		})
+	}
+	return destructs, accounts, storages
 }
 
 // PrepareAccessList handles the preparatory steps for executing a state transition with

+ 7 - 7
core/state/statedb_test.go

@@ -102,7 +102,7 @@ func TestIntermediateLeaks(t *testing.T) {
 	}
 
 	// Commit and cross check the databases.
-	transRoot, err := transState.Commit(false)
+	transRoot, _, err := transState.Commit(false)
 	if err != nil {
 		t.Fatalf("failed to commit transition state: %v", err)
 	}
@@ -110,7 +110,7 @@ func TestIntermediateLeaks(t *testing.T) {
 		t.Errorf("can not commit trie %v to persistent database", transRoot.Hex())
 	}
 
-	finalRoot, err := finalState.Commit(false)
+	finalRoot, _, err := finalState.Commit(false)
 	if err != nil {
 		t.Fatalf("failed to commit final state: %v", err)
 	}
@@ -473,7 +473,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
 func TestTouchDelete(t *testing.T) {
 	s := newStateTest()
 	s.state.GetOrNewStateObject(common.Address{})
-	root, _ := s.state.Commit(false)
+	root, _, _ := s.state.Commit(false)
 	s.state, _ = New(root, s.state.db, s.state.snaps)
 
 	snapshot := s.state.Snapshot()
@@ -675,7 +675,7 @@ func TestDeleteCreateRevert(t *testing.T) {
 	addr := common.BytesToAddress([]byte("so"))
 	state.SetBalance(addr, big.NewInt(1))
 
-	root, _ := state.Commit(false)
+	root, _, _ := state.Commit(false)
 	state, _ = New(root, state.db, state.snaps)
 
 	// Simulate self-destructing in one transaction, then create-reverting in another
@@ -687,7 +687,7 @@ func TestDeleteCreateRevert(t *testing.T) {
 	state.RevertToSnapshot(id)
 
 	// Commit the entire state and make sure we don't crash and have the correct state
-	root, _ = state.Commit(true)
+	root, _, _ = state.Commit(true)
 	state, _ = New(root, state.db, state.snaps)
 
 	if state.getStateObject(addr) != nil {
@@ -712,7 +712,7 @@ func TestMissingTrieNodes(t *testing.T) {
 		a2 := common.BytesToAddress([]byte("another"))
 		state.SetBalance(a2, big.NewInt(100))
 		state.SetCode(a2, []byte{1, 2, 4})
-		root, _ = state.Commit(false)
+		root, _, _ = state.Commit(false)
 		t.Logf("root: %x", root)
 		// force-flush
 		state.Database().TrieDB().Cap(0)
@@ -736,7 +736,7 @@ func TestMissingTrieNodes(t *testing.T) {
 	}
 	// Modify the state
 	state.SetBalance(addr, big.NewInt(2))
-	root, err := state.Commit(false)
+	root, _, err := state.Commit(false)
 	if err == nil {
 		t.Fatalf("expected error, got root :%x", root)
 	}

+ 1 - 1
core/state/sync_test.go

@@ -69,7 +69,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) {
 		state.updateStateObject(obj)
 		accounts = append(accounts, acc)
 	}
-	root, _ := state.Commit(false)
+	root, _, _ := state.Commit(false)
 
 	// Return the generated state
 	return db, root, accounts

+ 0 - 9
core/state_prefetcher.go

@@ -35,15 +35,6 @@ type statePrefetcher struct {
 	engine consensus.Engine    // Consensus engine used for block rewards
 }
 
-// newStatePrefetcher initialises a new statePrefetcher.
-func newStatePrefetcher(config *params.ChainConfig, bc *BlockChain, engine consensus.Engine) *statePrefetcher {
-	return &statePrefetcher{
-		config: config,
-		bc:     bc,
-		engine: engine,
-	}
-}
-
 // Prefetch processes the state changes according to the Ethereum rules by running
 // the transaction messages using the statedb, but any changes are discarded. The
 // only goal is to pre-cache transaction signatures and state trie nodes.

+ 322 - 7
core/state_processor.go

@@ -17,17 +17,36 @@
 package core
 
 import (
+	"bytes"
+	"errors"
 	"fmt"
+	"math/big"
+	"math/rand"
+	"sync"
+	"time"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/gopool"
 	"github.com/ethereum/go-ethereum/consensus"
 	"github.com/ethereum/go-ethereum/consensus/misc"
+	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/core/state"
+	"github.com/ethereum/go-ethereum/core/state/snapshot"
 	"github.com/ethereum/go-ethereum/core/systemcontracts"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/core/vm"
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+	fullProcessCheck          = 21 // On diff sync mode, will do full process every fullProcessCheck randomly
+	minNumberOfAccountPerTask = 5
+	recentTime                = 2048 * 3
+	recentDiffLayerTimeout    = 20
+	farDiffLayerTimeout       = 2
 )
 
 // StateProcessor is a basic Processor, which takes care of transitioning
@@ -49,6 +68,301 @@ func NewStateProcessor(config *params.ChainConfig, bc *BlockChain, engine consen
 	}
 }
 
+type LightStateProcessor struct {
+	randomGenerator *rand.Rand
+	StateProcessor
+}
+
+func NewLightStateProcessor(config *params.ChainConfig, bc *BlockChain, engine consensus.Engine) *LightStateProcessor {
+	randomGenerator := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
+	return &LightStateProcessor{
+		randomGenerator: randomGenerator,
+		StateProcessor:  *NewStateProcessor(config, bc, engine),
+	}
+}
+
+func (p *LightStateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg vm.Config) (*state.StateDB, types.Receipts, []*types.Log, uint64, error) {
+	allowLightProcess := true
+	if posa, ok := p.engine.(consensus.PoSA); ok {
+		allowLightProcess = posa.AllowLightProcess(p.bc, block.Header())
+	}
+	// random fallback to full process
+	if check := p.randomGenerator.Int63n(fullProcessCheck); allowLightProcess && check != 0 && len(block.Transactions()) != 0 {
+		var pid string
+		if peer, ok := block.ReceivedFrom.(PeerIDer); ok {
+			pid = peer.ID()
+		}
+		var diffLayer *types.DiffLayer
+		var diffLayerTimeout = recentDiffLayerTimeout
+		if time.Now().Unix()-int64(block.Time()) > recentTime {
+			diffLayerTimeout = farDiffLayerTimeout
+		}
+		for tried := 0; tried < diffLayerTimeout; tried++ {
+			// wait a bit for the diff layer
+			diffLayer = p.bc.GetUnTrustedDiffLayer(block.Hash(), pid)
+			if diffLayer != nil {
+				break
+			}
+			time.Sleep(time.Millisecond)
+		}
+		if diffLayer != nil {
+			if err := diffLayer.Receipts.DeriveFields(p.bc.chainConfig, block.Hash(), block.NumberU64(), block.Transactions()); err != nil {
+				log.Error("Failed to derive block receipts fields", "hash", block.Hash(), "number", block.NumberU64(), "err", err)
+				// fallback to full process
+				return p.StateProcessor.Process(block, statedb, cfg)
+			}
+
+			receipts, logs, gasUsed, err := p.LightProcess(diffLayer, block, statedb)
+			if err == nil {
+				log.Info("do light process success at block", "num", block.NumberU64())
+				return statedb, receipts, logs, gasUsed, nil
+			}
+			log.Error("do light process err at block", "num", block.NumberU64(), "err", err)
+			p.bc.removeDiffLayers(diffLayer.DiffHash)
+			// prepare new statedb
+			statedb.StopPrefetcher()
+			parent := p.bc.GetHeader(block.ParentHash(), block.NumberU64()-1)
+			statedb, err = state.New(parent.Root, p.bc.stateCache, p.bc.snaps)
+			if err != nil {
+				return statedb, nil, nil, 0, err
+			}
+			// Enable prefetching to pull in trie node paths while processing transactions
+			statedb.StartPrefetcher("chain")
+		}
+	}
+	// fallback to full process
+	return p.StateProcessor.Process(block, statedb, cfg)
+}
+
+func (p *LightStateProcessor) LightProcess(diffLayer *types.DiffLayer, block *types.Block, statedb *state.StateDB) (types.Receipts, []*types.Log, uint64, error) {
+	statedb.MarkLightProcessed()
+	fullDiffCode := make(map[common.Hash][]byte, len(diffLayer.Codes))
+	diffTries := make(map[common.Address]state.Trie)
+	diffCode := make(map[common.Hash][]byte)
+
+	snapDestructs, snapAccounts, snapStorage, err := statedb.DiffLayerToSnap(diffLayer)
+	if err != nil {
+		return nil, nil, 0, err
+	}
+
+	for _, c := range diffLayer.Codes {
+		fullDiffCode[c.Hash] = c.Code
+	}
+
+	for des := range snapDestructs {
+		statedb.Trie().TryDelete(des[:])
+	}
+	threads := gopool.Threads(len(snapAccounts))
+
+	iteAccounts := make([]common.Address, 0, len(snapAccounts))
+	for diffAccount := range snapAccounts {
+		iteAccounts = append(iteAccounts, diffAccount)
+	}
+
+	errChan := make(chan error, threads)
+	exitChan := make(chan struct{})
+	var snapMux sync.RWMutex
+	var stateMux, diffMux sync.Mutex
+	for i := 0; i < threads; i++ {
+		start := i * len(iteAccounts) / threads
+		end := (i + 1) * len(iteAccounts) / threads
+		if i+1 == threads {
+			end = len(iteAccounts)
+		}
+		go func(start, end int) {
+			for index := start; index < end; index++ {
+				select {
+				// fast fail
+				case <-exitChan:
+					return
+				default:
+				}
+				diffAccount := iteAccounts[index]
+				snapMux.RLock()
+				blob := snapAccounts[diffAccount]
+				snapMux.RUnlock()
+				addrHash := crypto.Keccak256Hash(diffAccount[:])
+				latestAccount, err := snapshot.FullAccount(blob)
+				if err != nil {
+					errChan <- err
+					return
+				}
+
+				// fetch previous state
+				var previousAccount state.Account
+				stateMux.Lock()
+				enc, err := statedb.Trie().TryGet(diffAccount[:])
+				stateMux.Unlock()
+				if err != nil {
+					errChan <- err
+					return
+				}
+				if len(enc) != 0 {
+					if err := rlp.DecodeBytes(enc, &previousAccount); err != nil {
+						errChan <- err
+						return
+					}
+				}
+				if latestAccount.Balance == nil {
+					latestAccount.Balance = new(big.Int)
+				}
+				if previousAccount.Balance == nil {
+					previousAccount.Balance = new(big.Int)
+				}
+				if previousAccount.Root == (common.Hash{}) {
+					previousAccount.Root = types.EmptyRootHash
+				}
+				if len(previousAccount.CodeHash) == 0 {
+					previousAccount.CodeHash = types.EmptyCodeHash
+				}
+
+				// skip no change account
+				if previousAccount.Nonce == latestAccount.Nonce &&
+					bytes.Equal(previousAccount.CodeHash, latestAccount.CodeHash) &&
+					previousAccount.Balance.Cmp(latestAccount.Balance) == 0 &&
+					previousAccount.Root == common.BytesToHash(latestAccount.Root) {
+					// It is normal to receive redundant message since the collected message is redundant.
+					log.Debug("receive redundant account change in diff layer", "account", diffAccount, "num", block.NumberU64())
+					snapMux.Lock()
+					delete(snapAccounts, diffAccount)
+					delete(snapStorage, diffAccount)
+					snapMux.Unlock()
+					continue
+				}
+
+				// update code
+				codeHash := common.BytesToHash(latestAccount.CodeHash)
+				if !bytes.Equal(latestAccount.CodeHash, previousAccount.CodeHash) &&
+					!bytes.Equal(latestAccount.CodeHash, types.EmptyCodeHash) {
+					if code, exist := fullDiffCode[codeHash]; exist {
+						if crypto.Keccak256Hash(code) != codeHash {
+							errChan <- fmt.Errorf("code and code hash mismatch, account %s", diffAccount.String())
+							return
+						}
+						diffMux.Lock()
+						diffCode[codeHash] = code
+						diffMux.Unlock()
+					} else {
+						rawCode := rawdb.ReadCode(p.bc.db, codeHash)
+						if len(rawCode) == 0 {
+							errChan <- fmt.Errorf("missing code, account %s", diffAccount.String())
+							return
+						}
+					}
+				}
+
+				//update storage
+				latestRoot := common.BytesToHash(latestAccount.Root)
+				if latestRoot != previousAccount.Root && latestRoot != types.EmptyRootHash {
+					accountTrie, err := statedb.Database().OpenStorageTrie(addrHash, previousAccount.Root)
+					if err != nil {
+						errChan <- err
+						return
+					}
+					snapMux.RLock()
+					storageChange, exist := snapStorage[diffAccount]
+					snapMux.RUnlock()
+
+					if !exist {
+						errChan <- errors.New("missing storage change in difflayer")
+						return
+					}
+					for k, v := range storageChange {
+						if len(v) != 0 {
+							accountTrie.TryUpdate([]byte(k), v)
+						} else {
+							accountTrie.TryDelete([]byte(k))
+						}
+					}
+
+					// check storage root
+					accountRootHash := accountTrie.Hash()
+					if latestRoot != accountRootHash {
+						errChan <- errors.New("account storage root mismatch")
+						return
+					}
+					diffMux.Lock()
+					diffTries[diffAccount] = accountTrie
+					diffMux.Unlock()
+				} else {
+					snapMux.Lock()
+					delete(snapStorage, diffAccount)
+					snapMux.Unlock()
+				}
+
+				// can't trust the blob, need encode by our-self.
+				latestStateAccount := state.Account{
+					Nonce:    latestAccount.Nonce,
+					Balance:  latestAccount.Balance,
+					Root:     common.BytesToHash(latestAccount.Root),
+					CodeHash: latestAccount.CodeHash,
+				}
+				bz, err := rlp.EncodeToBytes(&latestStateAccount)
+				if err != nil {
+					errChan <- err
+					return
+				}
+				stateMux.Lock()
+				err = statedb.Trie().TryUpdate(diffAccount[:], bz)
+				stateMux.Unlock()
+				if err != nil {
+					errChan <- err
+					return
+				}
+			}
+			errChan <- nil
+		}(start, end)
+	}
+
+	for i := 0; i < threads; i++ {
+		err := <-errChan
+		if err != nil {
+			close(exitChan)
+			return nil, nil, 0, err
+		}
+	}
+
+	var allLogs []*types.Log
+	var gasUsed uint64
+	for _, receipt := range diffLayer.Receipts {
+		allLogs = append(allLogs, receipt.Logs...)
+		gasUsed += receipt.GasUsed
+	}
+
+	// Do validate in advance so that we can fall back to full process
+	if err := p.bc.validator.ValidateState(block, statedb, diffLayer.Receipts, gasUsed); err != nil {
+		log.Error("validate state failed during diff sync", "error", err)
+		return nil, nil, 0, err
+	}
+
+	// remove redundant storage change
+	for account := range snapStorage {
+		if _, exist := snapAccounts[account]; !exist {
+			log.Warn("receive redundant storage change in diff layer")
+			delete(snapStorage, account)
+		}
+	}
+
+	// remove redundant code
+	if len(fullDiffCode) != len(diffLayer.Codes) {
+		diffLayer.Codes = make([]types.DiffCode, 0, len(diffCode))
+		for hash, code := range diffCode {
+			diffLayer.Codes = append(diffLayer.Codes, types.DiffCode{
+				Hash: hash,
+				Code: code,
+			})
+		}
+	}
+
+	statedb.SetSnapData(snapDestructs, snapAccounts, snapStorage)
+	if len(snapAccounts) != len(diffLayer.Accounts) || len(snapStorage) != len(diffLayer.Storages) {
+		diffLayer.Destructs, diffLayer.Accounts, diffLayer.Storages = statedb.SnapToDiffLayer()
+	}
+	statedb.SetDiff(diffLayer, diffTries, diffCode)
+
+	return diffLayer.Receipts, allLogs, gasUsed, nil
+}
+
 // Process processes the state changes according to the Ethereum rules by running
 // the transaction messages using the statedb and applying any rewards to both
 // the processor (coinbase) and any included uncles.
@@ -56,13 +370,15 @@ func NewStateProcessor(config *params.ChainConfig, bc *BlockChain, engine consen
 // Process returns the receipts and logs accumulated during the process and
 // returns the amount of gas that was used in the process. If any of the
 // transactions failed to execute due to insufficient gas it will return an error.
-func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg vm.Config) (types.Receipts, []*types.Log, uint64, error) {
+func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg vm.Config) (*state.StateDB, types.Receipts, []*types.Log, uint64, error) {
 	var (
 		usedGas = new(uint64)
 		header  = block.Header()
 		allLogs []*types.Log
 		gp      = new(GasPool).AddGas(block.GasLimit())
 	)
+	signer := types.MakeSigner(p.bc.chainConfig, block.Number())
+	statedb.TryPreload(block, signer)
 	var receipts = make([]*types.Receipt, 0)
 	// Mutate the block and state according to any hard-fork specs
 	if p.config.DAOForkSupport && p.config.DAOForkBlock != nil && p.config.DAOForkBlock.Cmp(block.Number()) == 0 {
@@ -79,11 +395,10 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg
 	commonTxs := make([]*types.Transaction, 0, len(block.Transactions()))
 	// usually do have two tx, one for validator set contract, another for system reward contract.
 	systemTxs := make([]*types.Transaction, 0, 2)
-	signer := types.MakeSigner(p.config, header.Number)
 	for i, tx := range block.Transactions() {
 		if isPoSA {
 			if isSystemTx, err := posa.IsSystemTransaction(tx, block.Header()); err != nil {
-				return nil, nil, 0, err
+				return statedb, nil, nil, 0, err
 			} else if isSystemTx {
 				systemTxs = append(systemTxs, tx)
 				continue
@@ -92,12 +407,12 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg
 
 		msg, err := tx.AsMessage(signer)
 		if err != nil {
-			return nil, nil, 0, err
+			return statedb, nil, nil, 0, err
 		}
 		statedb.Prepare(tx.Hash(), block.Hash(), i)
 		receipt, err := applyTransaction(msg, p.config, p.bc, nil, gp, statedb, header, tx, usedGas, vmenv)
 		if err != nil {
-			return nil, nil, 0, fmt.Errorf("could not apply tx %d [%v]: %w", i, tx.Hash().Hex(), err)
+			return statedb, nil, nil, 0, fmt.Errorf("could not apply tx %d [%v]: %w", i, tx.Hash().Hex(), err)
 		}
 
 		commonTxs = append(commonTxs, tx)
@@ -107,13 +422,13 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg
 	// Finalize the block, applying any consensus engine specific extras (e.g. block rewards)
 	err := p.engine.Finalize(p.bc, header, statedb, &commonTxs, block.Uncles(), &receipts, &systemTxs, usedGas)
 	if err != nil {
-		return receipts, allLogs, *usedGas, err
+		return statedb, receipts, allLogs, *usedGas, err
 	}
 	for _, receipt := range receipts {
 		allLogs = append(allLogs, receipt.Logs...)
 	}
 
-	return receipts, allLogs, *usedGas, nil
+	return statedb, receipts, allLogs, *usedGas, nil
 }
 
 func applyTransaction(msg types.Message, config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, evm *vm.EVM) (*types.Receipt, error) {

+ 1 - 1
core/types.go

@@ -47,5 +47,5 @@ type Processor interface {
 	// Process processes the state changes according to the Ethereum rules by running
 	// the transaction messages using the statedb and applying any rewards to both
 	// the processor (coinbase) and any included uncles.
-	Process(block *types.Block, statedb *state.StateDB, cfg vm.Config) (types.Receipts, []*types.Log, uint64, error)
+	Process(block *types.Block, statedb *state.StateDB, cfg vm.Config) (*state.StateDB, types.Receipts, []*types.Log, uint64, error)
 }

+ 88 - 1
core/types/block.go

@@ -19,6 +19,7 @@ package types
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"math/big"
@@ -28,11 +29,14 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/hexutil"
+	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 
 var (
-	EmptyRootHash  = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+	EmptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+	EmptyCodeHash = crypto.Keccak256(nil)
+
 	EmptyUncleHash = rlpHash([]*Header(nil))
 )
 
@@ -366,3 +370,86 @@ func (b *Block) Hash() common.Hash {
 }
 
 type Blocks []*Block
+
+type DiffLayer struct {
+	BlockHash common.Hash
+	Number    uint64
+	Receipts  Receipts // Receipts are duplicated stored to simplify the logic
+	Codes     []DiffCode
+	Destructs []common.Address
+	Accounts  []DiffAccount
+	Storages  []DiffStorage
+
+	DiffHash common.Hash
+}
+
+type extDiffLayer struct {
+	BlockHash common.Hash
+	Number    uint64
+	Receipts  []*ReceiptForStorage // Receipts are duplicated stored to simplify the logic
+	Codes     []DiffCode
+	Destructs []common.Address
+	Accounts  []DiffAccount
+	Storages  []DiffStorage
+}
+
+// DecodeRLP decodes the Ethereum
+func (d *DiffLayer) DecodeRLP(s *rlp.Stream) error {
+	var ed extDiffLayer
+	if err := s.Decode(&ed); err != nil {
+		return err
+	}
+	d.BlockHash, d.Number, d.Codes, d.Destructs, d.Accounts, d.Storages =
+		ed.BlockHash, ed.Number, ed.Codes, ed.Destructs, ed.Accounts, ed.Storages
+
+	d.Receipts = make([]*Receipt, len(ed.Receipts))
+	for i, storageReceipt := range ed.Receipts {
+		d.Receipts[i] = (*Receipt)(storageReceipt)
+	}
+	return nil
+}
+
+// EncodeRLP serializes b into the Ethereum RLP block format.
+func (d *DiffLayer) EncodeRLP(w io.Writer) error {
+	storageReceipts := make([]*ReceiptForStorage, len(d.Receipts))
+	for i, receipt := range d.Receipts {
+		storageReceipts[i] = (*ReceiptForStorage)(receipt)
+	}
+	return rlp.Encode(w, extDiffLayer{
+		BlockHash: d.BlockHash,
+		Number:    d.Number,
+		Receipts:  storageReceipts,
+		Codes:     d.Codes,
+		Destructs: d.Destructs,
+		Accounts:  d.Accounts,
+		Storages:  d.Storages,
+	})
+}
+
+func (d *DiffLayer) Validate() error {
+	if d.BlockHash == (common.Hash{}) {
+		return errors.New("blockHash can't be empty")
+	}
+	for _, storage := range d.Storages {
+		if len(storage.Keys) != len(storage.Vals) {
+			return errors.New("the length of keys and values mismatch in storage")
+		}
+	}
+	return nil
+}
+
+type DiffCode struct {
+	Hash common.Hash
+	Code []byte
+}
+
+type DiffAccount struct {
+	Account common.Address
+	Blob    []byte
+}
+
+type DiffStorage struct {
+	Account common.Address
+	Keys    []string
+	Vals    [][]byte
+}

+ 1 - 1
core/vm/contracts_lightclient_test.go

@@ -10,7 +10,7 @@ import (
 )
 
 const (
-	testHeight            uint64 = 66848226
+	testHeight uint64 = 66848226
 )
 
 func TestTmHeaderValidateAndMerkleProofValidate(t *testing.T) {

+ 1 - 1
core/vm/lightclient/types.go

@@ -103,7 +103,7 @@ func (cs ConsensusState) EncodeConsensusState() ([]byte, error) {
 	copy(encodingBytes[pos:pos+chainIDLength], cs.ChainID)
 	pos += chainIDLength
 
-	binary.BigEndian.PutUint64(encodingBytes[pos:pos+heightLength], uint64(cs.Height))
+	binary.BigEndian.PutUint64(encodingBytes[pos:pos+heightLength], cs.Height)
 	pos += heightLength
 
 	copy(encodingBytes[pos:pos+appHashLength], cs.AppHash)

+ 15 - 3
eth/backend.go

@@ -42,6 +42,7 @@ import (
 	"github.com/ethereum/go-ethereum/eth/ethconfig"
 	"github.com/ethereum/go-ethereum/eth/filters"
 	"github.com/ethereum/go-ethereum/eth/gasprice"
+	"github.com/ethereum/go-ethereum/eth/protocols/diff"
 	"github.com/ethereum/go-ethereum/eth/protocols/eth"
 	"github.com/ethereum/go-ethereum/eth/protocols/snap"
 	"github.com/ethereum/go-ethereum/ethdb"
@@ -128,7 +129,8 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) {
 	ethashConfig.NotifyFull = config.Miner.NotifyFull
 
 	// Assemble the Ethereum object
-	chainDb, err := stack.OpenDatabaseWithFreezer("chaindata", config.DatabaseCache, config.DatabaseHandles, config.DatabaseFreezer, "eth/db/chaindata/", false)
+	chainDb, err := stack.OpenAndMergeDatabase("chaindata", config.DatabaseCache, config.DatabaseHandles,
+		config.DatabaseFreezer, config.DatabaseDiff, "eth/db/chaindata/", false, config.PersistDiff)
 	if err != nil {
 		return nil, err
 	}
@@ -197,7 +199,14 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) {
 			Preimages:          config.Preimages,
 		}
 	)
-	eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, chainConfig, eth.engine, vmConfig, eth.shouldPreserve, &config.TxLookupLimit)
+	bcOps := make([]core.BlockChainOption, 0)
+	if config.DiffSync {
+		bcOps = append(bcOps, core.EnableLightProcessor)
+	}
+	if config.PersistDiff {
+		bcOps = append(bcOps, core.EnablePersistDiff(config.DiffBlock))
+	}
+	eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, chainConfig, eth.engine, vmConfig, eth.shouldPreserve, &config.TxLookupLimit, bcOps...)
 	if err != nil {
 		return nil, err
 	}
@@ -232,6 +241,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) {
 		Checkpoint:      checkpoint,
 		Whitelist:       config.Whitelist,
 		DirectBroadcast: config.DirectBroadcast,
+		DiffSync:        config.DiffSync,
 	}); err != nil {
 		return nil, err
 	}
@@ -534,9 +544,11 @@ func (s *Ethereum) BloomIndexer() *core.ChainIndexer   { return s.bloomIndexer }
 // network protocols to start.
 func (s *Ethereum) Protocols() []p2p.Protocol {
 	protos := eth.MakeProtocols((*ethHandler)(s.handler), s.networkID, s.ethDialCandidates)
-	if s.config.SnapshotCache > 0 {
+	if !s.config.DisableSnapProtocol && s.config.SnapshotCache > 0 {
 		protos = append(protos, snap.MakeProtocols((*snapHandler)(s.handler), s.snapDialCandidates)...)
 	}
+	// diff protocol can still open without snap protocol
+	protos = append(protos, diff.MakeProtocols((*diffHandler)(s.handler), s.snapDialCandidates)...)
 	return protos
 }
 

+ 49 - 7
eth/downloader/downloader.go

@@ -161,10 +161,10 @@ type Downloader struct {
 	quitLock sync.Mutex    // Lock to prevent double closes
 
 	// Testing hooks
-	syncInitHook     func(uint64, uint64)  // Method to call upon initiating a new sync run
-	bodyFetchHook    func([]*types.Header) // Method to call upon starting a block body fetch
-	receiptFetchHook func([]*types.Header) // Method to call upon starting a receipt fetch
-	chainInsertHook  func([]*fetchResult)  // Method to call upon inserting a chain of blocks (possibly in multiple invocations)
+	syncInitHook     func(uint64, uint64)                  // Method to call upon initiating a new sync run
+	bodyFetchHook    func([]*types.Header, ...interface{}) // Method to call upon starting a block body fetch
+	receiptFetchHook func([]*types.Header, ...interface{}) // Method to call upon starting a receipt fetch
+	chainInsertHook  func([]*fetchResult)                  // Method to call upon inserting a chain of blocks (possibly in multiple invocations)
 }
 
 // LightChain encapsulates functions required to synchronise a light chain.
@@ -220,8 +220,45 @@ type BlockChain interface {
 	Snapshots() *snapshot.Tree
 }
 
+type DownloadOption func(downloader *Downloader) *Downloader
+
+type IDiffPeer interface {
+	RequestDiffLayers([]common.Hash) error
+}
+
+type IPeerSet interface {
+	GetDiffPeer(string) IDiffPeer
+}
+
+func EnableDiffFetchOp(peers IPeerSet) DownloadOption {
+	return func(dl *Downloader) *Downloader {
+		var hook = func(headers []*types.Header, args ...interface{}) {
+			if len(args) < 2 {
+				return
+			}
+			peerID, ok := args[1].(string)
+			if !ok {
+				return
+			}
+			mode, ok := args[0].(SyncMode)
+			if !ok {
+				return
+			}
+			if ep := peers.GetDiffPeer(peerID); mode == FullSync && ep != nil {
+				hashes := make([]common.Hash, 0, len(headers))
+				for _, header := range headers {
+					hashes = append(hashes, header.Hash())
+				}
+				ep.RequestDiffLayers(hashes)
+			}
+		}
+		dl.bodyFetchHook = hook
+		return dl
+	}
+}
+
 // New creates a new downloader to fetch hashes and blocks from remote peers.
-func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn) *Downloader {
+func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn, options ...DownloadOption) *Downloader {
 	if lightchain == nil {
 		lightchain = chain
 	}
@@ -252,6 +289,11 @@ func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom,
 		},
 		trackStateReq: make(chan *stateReq),
 	}
+	for _, option := range options {
+		if dl != nil {
+			dl = option(dl)
+		}
+	}
 	go dl.qosTuner()
 	go dl.stateFetcher()
 	return dl
@@ -1363,7 +1405,7 @@ func (d *Downloader) fetchReceipts(from uint64) error {
 //  - kind:        textual label of the type being downloaded to display in log messages
 func (d *Downloader) fetchParts(deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool,
 	expire func() map[string]int, pending func() int, inFlight func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, bool),
-	fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int,
+	fetchHook func([]*types.Header, ...interface{}), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int,
 	idle func() ([]*peerConnection, int), setIdle func(*peerConnection, int, time.Time), kind string) error {
 
 	// Create a ticker to detect expired retrieval tasks
@@ -1512,7 +1554,7 @@ func (d *Downloader) fetchParts(deliveryCh chan dataPack, deliver func(dataPack)
 				}
 				// Fetch the chunk and make sure any errors return the hashes to the queue
 				if fetchHook != nil {
-					fetchHook(request.Headers)
+					fetchHook(request.Headers, d.getMode(), peer.id)
 				}
 				if err := fetch(peer, request); err != nil {
 					// Although we could try and make an attempt to fix this, this error really

+ 2 - 2
eth/downloader/downloader_test.go

@@ -921,10 +921,10 @@ func testEmptyShortCircuit(t *testing.T, protocol uint, mode SyncMode) {
 
 	// Instrument the downloader to signal body requests
 	bodiesHave, receiptsHave := int32(0), int32(0)
-	tester.downloader.bodyFetchHook = func(headers []*types.Header) {
+	tester.downloader.bodyFetchHook = func(headers []*types.Header, _ ...interface{}) {
 		atomic.AddInt32(&bodiesHave, int32(len(headers)))
 	}
-	tester.downloader.receiptFetchHook = func(headers []*types.Header) {
+	tester.downloader.receiptFetchHook = func(headers []*types.Header, _ ...interface{}) {
 		atomic.AddInt32(&receiptsHave, int32(len(headers)))
 	}
 	// Synchronise with the peer and make sure all blocks were retrieved

+ 9 - 3
eth/ethconfig/config.go

@@ -80,6 +80,7 @@ var Defaults = Config{
 	TrieTimeout:             60 * time.Minute,
 	TriesInMemory:           128,
 	SnapshotCache:           102,
+	DiffBlock:               uint64(864000),
 	Miner: miner.Config{
 		GasFloor:      8000000,
 		GasCeil:       8000000,
@@ -131,9 +132,11 @@ type Config struct {
 	EthDiscoveryURLs  []string
 	SnapDiscoveryURLs []string
 
-	NoPruning       bool // Whether to disable pruning and flush everything to disk
-	DirectBroadcast bool
-	RangeLimit      bool
+	NoPruning           bool // Whether to disable pruning and flush everything to disk
+	DirectBroadcast     bool
+	DisableSnapProtocol bool //Whether disable snap protocol
+	DiffSync            bool // Whether support diff sync
+	RangeLimit          bool
 
 	TxLookupLimit uint64 `toml:",omitempty"` // The maximum number of blocks from head whose tx indices are reserved.
 
@@ -159,6 +162,9 @@ type Config struct {
 	DatabaseHandles    int  `toml:"-"`
 	DatabaseCache      int
 	DatabaseFreezer    string
+	DatabaseDiff       string
+	PersistDiff        bool
+	DiffBlock          uint64
 
 	TrieCleanCache          int
 	TrieCleanCacheJournal   string        `toml:",omitempty"` // Disk journal directory for trie cache to survive node restarts

+ 18 - 0
eth/ethconfig/gen_config.go

@@ -40,6 +40,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
 		DatabaseHandles         int                    `toml:"-"`
 		DatabaseCache           int
 		DatabaseFreezer         string
+		DatabaseDiff            string
 		TrieCleanCache          int
 		TrieCleanCacheJournal   string        `toml:",omitempty"`
 		TrieCleanCacheRejournal time.Duration `toml:",omitempty"`
@@ -48,6 +49,8 @@ func (c Config) MarshalTOML() (interface{}, error) {
 		TriesInMemory           uint64 `toml:",omitempty"`
 		SnapshotCache           int
 		Preimages               bool
+		PersistDiff             bool
+		DiffBlock               uint64 `toml:",omitempty"`
 		Miner                   miner.Config
 		Ethash                  ethash.Config
 		TxPool                  core.TxPoolConfig
@@ -84,6 +87,7 @@ func (c Config) MarshalTOML() (interface{}, error) {
 	enc.DatabaseHandles = c.DatabaseHandles
 	enc.DatabaseCache = c.DatabaseCache
 	enc.DatabaseFreezer = c.DatabaseFreezer
+	enc.DatabaseDiff = c.DatabaseDiff
 	enc.TrieCleanCache = c.TrieCleanCache
 	enc.TrieCleanCacheJournal = c.TrieCleanCacheJournal
 	enc.TrieCleanCacheRejournal = c.TrieCleanCacheRejournal
@@ -92,6 +96,8 @@ func (c Config) MarshalTOML() (interface{}, error) {
 	enc.TriesInMemory = c.TriesInMemory
 	enc.SnapshotCache = c.SnapshotCache
 	enc.Preimages = c.Preimages
+	enc.PersistDiff = c.PersistDiff
+	enc.DiffBlock = c.DiffBlock
 	enc.Miner = c.Miner
 	enc.Ethash = c.Ethash
 	enc.TxPool = c.TxPool
@@ -133,6 +139,9 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
 		DatabaseHandles         *int                   `toml:"-"`
 		DatabaseCache           *int
 		DatabaseFreezer         *string
+		DatabaseDiff            *string
+		PersistDiff             *bool
+		DiffBlock               *uint64 `toml:",omitempty"`
 		TrieCleanCache          *int
 		TrieCleanCacheJournal   *string        `toml:",omitempty"`
 		TrieCleanCacheRejournal *time.Duration `toml:",omitempty"`
@@ -224,6 +233,15 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error {
 	if dec.DatabaseFreezer != nil {
 		c.DatabaseFreezer = *dec.DatabaseFreezer
 	}
+	if dec.DatabaseDiff != nil {
+		c.DatabaseDiff = *dec.DatabaseDiff
+	}
+	if dec.PersistDiff != nil {
+		c.PersistDiff = *dec.PersistDiff
+	}
+	if dec.DiffBlock != nil {
+		c.DiffBlock = *dec.DiffBlock
+	}
 	if dec.TrieCleanCache != nil {
 		c.TrieCleanCache = *dec.TrieCleanCache
 	}

+ 12 - 1
eth/fetcher/block_fetcher.go

@@ -82,6 +82,9 @@ type headerRequesterFn func(common.Hash) error
 // bodyRequesterFn is a callback type for sending a body retrieval request.
 type bodyRequesterFn func([]common.Hash) error
 
+// DiffRequesterFn is a callback type for sending a diff layer retrieval request.
+type DiffRequesterFn func([]common.Hash) error
+
 // headerVerifierFn is a callback type to verify a block's header for fast propagation.
 type headerVerifierFn func(header *types.Header) error
 
@@ -112,6 +115,8 @@ type blockAnnounce struct {
 
 	fetchHeader headerRequesterFn // Fetcher function to retrieve the header of an announced block
 	fetchBodies bodyRequesterFn   // Fetcher function to retrieve the body of an announced block
+	fetchDiffs  DiffRequesterFn   // Fetcher function to retrieve the diff layer of an announced block
+
 }
 
 // headerFilterTask represents a batch of headers needing fetcher filtering.
@@ -246,7 +251,7 @@ func (f *BlockFetcher) Stop() {
 // Notify announces the fetcher of the potential availability of a new block in
 // the network.
 func (f *BlockFetcher) Notify(peer string, hash common.Hash, number uint64, time time.Time,
-	headerFetcher headerRequesterFn, bodyFetcher bodyRequesterFn) error {
+	headerFetcher headerRequesterFn, bodyFetcher bodyRequesterFn, diffFetcher DiffRequesterFn) error {
 	block := &blockAnnounce{
 		hash:        hash,
 		number:      number,
@@ -254,6 +259,7 @@ func (f *BlockFetcher) Notify(peer string, hash common.Hash, number uint64, time
 		origin:      peer,
 		fetchHeader: headerFetcher,
 		fetchBodies: bodyFetcher,
+		fetchDiffs:  diffFetcher,
 	}
 	select {
 	case f.notify <- block:
@@ -481,10 +487,15 @@ func (f *BlockFetcher) loop() {
 
 				// Create a closure of the fetch and schedule in on a new thread
 				fetchHeader, hashes := f.fetching[hashes[0]].fetchHeader, hashes
+				fetchDiff := f.fetching[hashes[0]].fetchDiffs
+
 				gopool.Submit(func() {
 					if f.fetchingHook != nil {
 						f.fetchingHook(hashes)
 					}
+					if fetchDiff != nil {
+						fetchDiff(hashes)
+					}
 					for _, hash := range hashes {
 						headerFetchMeter.Mark(1)
 						fetchHeader(hash) // Suboptimal, but protocol doesn't allow batch header retrievals

+ 18 - 18
eth/fetcher/block_fetcher_test.go

@@ -343,7 +343,7 @@ func testSequentialAnnouncements(t *testing.T, light bool) {
 		}
 	}
 	for i := len(hashes) - 2; i >= 0; i-- {
-		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 		verifyImportEvent(t, imported, true)
 	}
 	verifyImportDone(t, imported)
@@ -392,9 +392,9 @@ func testConcurrentAnnouncements(t *testing.T, light bool) {
 		}
 	}
 	for i := len(hashes) - 2; i >= 0; i-- {
-		tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), firstHeaderWrapper, firstBodyFetcher)
-		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
-		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
+		tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), firstHeaderWrapper, firstBodyFetcher, nil)
+		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), secondHeaderWrapper, secondBodyFetcher, nil)
+		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), secondHeaderWrapper, secondBodyFetcher, nil)
 		verifyImportEvent(t, imported, true)
 	}
 	verifyImportDone(t, imported)
@@ -441,7 +441,7 @@ func testOverlappingAnnouncements(t *testing.T, light bool) {
 	}
 
 	for i := len(hashes) - 2; i >= 0; i-- {
-		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 		select {
 		case <-imported:
 		case <-time.After(time.Second):
@@ -488,7 +488,7 @@ func testPendingDeduplication(t *testing.T, light bool) {
 	}
 	// Announce the same block many times until it's fetched (wait for any pending ops)
 	for checkNonExist() {
-		tester.fetcher.Notify("repeater", hashes[0], 1, time.Now().Add(-arriveTimeout), headerWrapper, bodyFetcher)
+		tester.fetcher.Notify("repeater", hashes[0], 1, time.Now().Add(-arriveTimeout), headerWrapper, bodyFetcher, nil)
 		time.Sleep(time.Millisecond)
 	}
 	time.Sleep(delay)
@@ -532,12 +532,12 @@ func testRandomArrivalImport(t *testing.T, light bool) {
 	}
 	for i := len(hashes) - 1; i >= 0; i-- {
 		if i != skip {
-			tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+			tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 			time.Sleep(time.Millisecond)
 		}
 	}
 	// Finally announce the skipped entry and check full import
-	tester.fetcher.Notify("valid", hashes[skip], uint64(len(hashes)-skip-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+	tester.fetcher.Notify("valid", hashes[skip], uint64(len(hashes)-skip-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 	verifyImportCount(t, imported, len(hashes)-1)
 	verifyChainHeight(t, tester, uint64(len(hashes)-1))
 }
@@ -560,7 +560,7 @@ func TestQueueGapFill(t *testing.T) {
 
 	for i := len(hashes) - 1; i >= 0; i-- {
 		if i != skip {
-			tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+			tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 			time.Sleep(time.Millisecond)
 		}
 	}
@@ -593,7 +593,7 @@ func TestImportDeduplication(t *testing.T) {
 	tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block }
 
 	// Announce the duplicating block, wait for retrieval, and also propagate directly
-	tester.fetcher.Notify("valid", hashes[0], 1, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+	tester.fetcher.Notify("valid", hashes[0], 1, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 	<-fetching
 
 	tester.fetcher.Enqueue("valid", blocks[hashes[0]])
@@ -669,14 +669,14 @@ func testDistantAnnouncementDiscarding(t *testing.T, light bool) {
 	tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- struct{}{} }
 
 	// Ensure that a block with a lower number than the threshold is discarded
-	tester.fetcher.Notify("lower", hashes[low], blocks[hashes[low]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+	tester.fetcher.Notify("lower", hashes[low], blocks[hashes[low]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 	select {
 	case <-time.After(50 * time.Millisecond):
 	case <-fetching:
 		t.Fatalf("fetcher requested stale header")
 	}
 	// Ensure that a block with a higher number than the threshold is discarded
-	tester.fetcher.Notify("higher", hashes[high], blocks[hashes[high]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+	tester.fetcher.Notify("higher", hashes[high], blocks[hashes[high]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 	select {
 	case <-time.After(50 * time.Millisecond):
 	case <-fetching:
@@ -712,7 +712,7 @@ func testInvalidNumberAnnouncement(t *testing.T, light bool) {
 		}
 	}
 	// Announce a block with a bad number, check for immediate drop
-	tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), badHeaderFetcher, badBodyFetcher)
+	tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), badHeaderFetcher, badBodyFetcher, nil)
 	verifyImportEvent(t, imported, false)
 
 	tester.lock.RLock()
@@ -726,7 +726,7 @@ func testInvalidNumberAnnouncement(t *testing.T, light bool) {
 	goodHeaderFetcher := tester.makeHeaderFetcher("good", blocks, -gatherSlack)
 	goodBodyFetcher := tester.makeBodyFetcher("good", blocks, 0)
 	// Make sure a good announcement passes without a drop
-	tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), goodHeaderFetcher, goodBodyFetcher)
+	tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), goodHeaderFetcher, goodBodyFetcher, nil)
 	verifyImportEvent(t, imported, true)
 
 	tester.lock.RLock()
@@ -765,7 +765,7 @@ func TestEmptyBlockShortCircuit(t *testing.T) {
 	}
 	// Iteratively announce blocks until all are imported
 	for i := len(hashes) - 2; i >= 0; i-- {
-		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher, nil)
 
 		// All announces should fetch the header
 		verifyFetchingEvent(t, fetching, true)
@@ -808,9 +808,9 @@ func TestHashMemoryExhaustionAttack(t *testing.T) {
 	// Feed the tester a huge hashset from the attacker, and a limited from the valid peer
 	for i := 0; i < len(attack); i++ {
 		if i < maxQueueDist {
-			tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], uint64(i+1), time.Now(), validHeaderFetcher, validBodyFetcher)
+			tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], uint64(i+1), time.Now(), validHeaderFetcher, validBodyFetcher, nil)
 		}
-		tester.fetcher.Notify("attacker", attack[i], 1 /* don't distance drop */, time.Now(), attackerHeaderFetcher, attackerBodyFetcher)
+		tester.fetcher.Notify("attacker", attack[i], 1 /* don't distance drop */, time.Now(), attackerHeaderFetcher, attackerBodyFetcher, nil)
 	}
 	if count := atomic.LoadInt32(&announces); count != hashLimit+maxQueueDist {
 		t.Fatalf("queued announce count mismatch: have %d, want %d", count, hashLimit+maxQueueDist)
@@ -820,7 +820,7 @@ func TestHashMemoryExhaustionAttack(t *testing.T) {
 
 	// Feed the remaining valid hashes to ensure DOS protection state remains clean
 	for i := len(hashes) - maxQueueDist - 2; i >= 0; i-- {
-		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), validHeaderFetcher, validBodyFetcher)
+		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), validHeaderFetcher, validBodyFetcher, nil)
 		verifyImportEvent(t, imported, true)
 	}
 	verifyImportDone(t, imported)

+ 38 - 3
eth/handler.go

@@ -30,6 +30,7 @@ import (
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/eth/downloader"
 	"github.com/ethereum/go-ethereum/eth/fetcher"
+	"github.com/ethereum/go-ethereum/eth/protocols/diff"
 	"github.com/ethereum/go-ethereum/eth/protocols/eth"
 	"github.com/ethereum/go-ethereum/eth/protocols/snap"
 	"github.com/ethereum/go-ethereum/ethdb"
@@ -37,6 +38,7 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/trie"
 )
 
@@ -81,6 +83,7 @@ type handlerConfig struct {
 	TxPool          txPool                    // Transaction pool to propagate from
 	Network         uint64                    // Network identifier to adfvertise
 	Sync            downloader.SyncMode       // Whether to fast or full sync
+	DiffSync        bool                      // Whether to diff sync
 	BloomCache      uint64                    // Megabytes to alloc for fast sync bloom
 	EventMux        *event.TypeMux            // Legacy event mux, deprecate for `feed`
 	Checkpoint      *params.TrustedCheckpoint // Hard coded checkpoint for sync challenges
@@ -96,6 +99,7 @@ type handler struct {
 	snapSync        uint32 // Flag whether fast sync should operate on top of the snap protocol
 	acceptTxs       uint32 // Flag whether we're considered synchronised (enables transaction processing)
 	directBroadcast bool
+	diffSync        bool // Flag whether diff sync should operate on top of the diff protocol
 
 	checkpointNumber uint64      // Block number for the sync progress validator to cross reference
 	checkpointHash   common.Hash // Block hash for the sync progress validator to cross reference
@@ -143,6 +147,7 @@ func newHandler(config *handlerConfig) (*handler, error) {
 		peers:           newPeerSet(),
 		whitelist:       config.Whitelist,
 		directBroadcast: config.DirectBroadcast,
+		diffSync:        config.DiffSync,
 		txsyncCh:        make(chan *txsync),
 		quitSync:        make(chan struct{}),
 	}
@@ -187,7 +192,11 @@ func newHandler(config *handlerConfig) (*handler, error) {
 	if atomic.LoadUint32(&h.fastSync) == 1 && atomic.LoadUint32(&h.snapSync) == 0 {
 		h.stateBloom = trie.NewSyncBloom(config.BloomCache, config.Database)
 	}
-	h.downloader = downloader.New(h.checkpointNumber, config.Database, h.stateBloom, h.eventMux, h.chain, nil, h.removePeer)
+	var downloadOptions []downloader.DownloadOption
+	if h.diffSync {
+		downloadOptions = append(downloadOptions, downloader.EnableDiffFetchOp(h.peers))
+	}
+	h.downloader = downloader.New(h.checkpointNumber, config.Database, h.stateBloom, h.eventMux, h.chain, nil, h.removePeer, downloadOptions...)
 
 	// Construct the fetcher (short sync)
 	validator := func(header *types.Header) error {
@@ -246,6 +255,11 @@ func (h *handler) runEthPeer(peer *eth.Peer, handler eth.Handler) error {
 		peer.Log().Error("Snapshot extension barrier failed", "err", err)
 		return err
 	}
+	diff, err := h.peers.waitDiffExtension(peer)
+	if err != nil {
+		peer.Log().Error("Diff extension barrier failed", "err", err)
+		return err
+	}
 	// TODO(karalabe): Not sure why this is needed
 	if !h.chainSync.handlePeerEvent(peer) {
 		return p2p.DiscQuitting
@@ -286,7 +300,7 @@ func (h *handler) runEthPeer(peer *eth.Peer, handler eth.Handler) error {
 	peer.Log().Debug("Ethereum peer connected", "name", peer.Name())
 
 	// Register the peer locally
-	if err := h.peers.registerPeer(peer, snap); err != nil {
+	if err := h.peers.registerPeer(peer, snap, diff); err != nil {
 		peer.Log().Error("Ethereum peer registration failed", "err", err)
 		return err
 	}
@@ -357,6 +371,21 @@ func (h *handler) runSnapExtension(peer *snap.Peer, handler snap.Handler) error
 	return handler(peer)
 }
 
+// runDiffExtension registers a `diff` peer into the joint eth/diff peerset and
+// starts handling inbound messages. As `diff` is only a satellite protocol to
+// `eth`, all subsystem registrations and lifecycle management will be done by
+// the main `eth` handler to prevent strange races.
+func (h *handler) runDiffExtension(peer *diff.Peer, handler diff.Handler) error {
+	h.peerWG.Add(1)
+	defer h.peerWG.Done()
+
+	if err := h.peers.registerDiffExtension(peer); err != nil {
+		peer.Log().Error("Diff extension registration failed", "err", err)
+		return err
+	}
+	return handler(peer)
+}
+
 // removePeer unregisters a peer from the downloader and fetchers, removes it from
 // the set of tracked peers and closes the network connection to it.
 func (h *handler) removePeer(id string) {
@@ -449,13 +478,19 @@ func (h *handler) BroadcastBlock(block *types.Block, propagate bool) {
 		// Send the block to a subset of our peers
 		var transfer []*ethPeer
 		if h.directBroadcast {
-			transfer = peers[:int(len(peers))]
+			transfer = peers[:]
 		} else {
 			transfer = peers[:int(math.Sqrt(float64(len(peers))))]
 		}
+		diff := h.chain.GetDiffLayerRLP(block.Hash())
 		for _, peer := range transfer {
+			if len(diff) != 0 && peer.diffExt != nil {
+				// difflayer should send before block
+				peer.diffExt.SendDiffLayers([]rlp.RawValue{diff})
+			}
 			peer.AsyncSendNewBlock(block, td)
 		}
+
 		log.Trace("Propagated block", "hash", hash, "recipients", len(transfer), "duration", common.PrettyDuration(time.Since(block.ReceivedAt)))
 		return
 	}

+ 87 - 0
eth/handler_diff.go

@@ -0,0 +1,87 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"fmt"
+
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/eth/protocols/diff"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// diffHandler implements the diff.Backend interface to handle the various network
+// packets that are sent as replies or broadcasts.
+type diffHandler handler
+
+func (h *diffHandler) Chain() *core.BlockChain { return h.chain }
+
+// RunPeer is invoked when a peer joins on the `diff` protocol.
+func (h *diffHandler) RunPeer(peer *diff.Peer, hand diff.Handler) error {
+	if err := peer.Handshake(h.diffSync); err != nil {
+		return err
+	}
+	defer h.chain.RemoveDiffPeer(peer.ID())
+	return (*handler)(h).runDiffExtension(peer, hand)
+}
+
+// PeerInfo retrieves all known `diff` information about a peer.
+func (h *diffHandler) PeerInfo(id enode.ID) interface{} {
+	if p := h.peers.peer(id.String()); p != nil && p.diffExt != nil {
+		return p.diffExt.info()
+	}
+	return nil
+}
+
+// Handle is invoked from a peer's message handler when it receives a new remote
+// message that the handler couldn't consume and serve itself.
+func (h *diffHandler) Handle(peer *diff.Peer, packet diff.Packet) error {
+	// DeliverSnapPacket is invoked from a peer's message handler when it transmits a
+	// data packet for the local node to consume.
+	switch packet := packet.(type) {
+	case *diff.DiffLayersPacket:
+		return h.handleDiffLayerPackage(packet, peer.ID(), false)
+
+	case *diff.FullDiffLayersPacket:
+		return h.handleDiffLayerPackage(&packet.DiffLayersPacket, peer.ID(), true)
+
+	default:
+		return fmt.Errorf("unexpected diff packet type: %T", packet)
+	}
+}
+
+func (h *diffHandler) handleDiffLayerPackage(packet *diff.DiffLayersPacket, pid string, fulfilled bool) error {
+	diffs, err := packet.Unpack()
+
+	if err != nil {
+		return err
+	}
+	for _, d := range diffs {
+		if d != nil {
+			if err := d.Validate(); err != nil {
+				return err
+			}
+		}
+	}
+	for _, diff := range diffs {
+		err := h.chain.HandleDiffLayer(diff, pid, fulfilled)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}

+ 203 - 0
eth/handler_diff_test.go

@@ -0,0 +1,203 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"crypto/rand"
+	"math/big"
+	"testing"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/core/vm"
+	"github.com/ethereum/go-ethereum/eth/downloader"
+	"github.com/ethereum/go-ethereum/eth/protocols/diff"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// testBackend is a mock implementation of the live Ethereum message handler. Its
+// purpose is to allow testing the request/reply workflows and wire serialization
+// in the `eth` protocol without actually doing any data processing.
+type testBackend struct {
+	db     ethdb.Database
+	chain  *core.BlockChain
+	txpool *core.TxPool
+
+	handler *handler
+}
+
+// newTestBackend creates an empty chain and wraps it into a mock backend.
+func newTestBackend(blocks int) *testBackend {
+	return newTestBackendWithGenerator(blocks)
+}
+
+// newTestBackend creates a chain with a number of explicitly defined blocks and
+// wraps it into a mock backend.
+func newTestBackendWithGenerator(blocks int) *testBackend {
+	signer := types.HomesteadSigner{}
+	// Create a database pre-initialize with a genesis block
+	db := rawdb.NewMemoryDatabase()
+	(&core.Genesis{
+		Config: params.TestChainConfig,
+		Alloc:  core.GenesisAlloc{testAddr: {Balance: big.NewInt(100000000000000000)}},
+	}).MustCommit(db)
+
+	chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
+	generator := func(i int, block *core.BlockGen) {
+		// The chain maker doesn't have access to a chain, so the difficulty will be
+		// lets unset (nil). Set it here to the correct value.
+		block.SetCoinbase(testAddr)
+
+		// We want to simulate an empty middle block, having the same state as the
+		// first one. The last is needs a state change again to force a reorg.
+		tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), common.Address{0x01}, big.NewInt(1), params.TxGas, big.NewInt(1), nil), signer, testKey)
+		if err != nil {
+			panic(err)
+		}
+		block.AddTxWithChain(chain, tx)
+	}
+	bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, generator)
+	if _, err := chain.InsertChain(bs); err != nil {
+		panic(err)
+	}
+	txpool := newTestTxPool()
+
+	handler, _ := newHandler(&handlerConfig{
+		Database:   db,
+		Chain:      chain,
+		TxPool:     txpool,
+		Network:    1,
+		Sync:       downloader.FullSync,
+		BloomCache: 1,
+	})
+	handler.Start(100)
+
+	txconfig := core.DefaultTxPoolConfig
+	txconfig.Journal = "" // Don't litter the disk with test journals
+
+	return &testBackend{
+		db:      db,
+		chain:   chain,
+		txpool:  core.NewTxPool(txconfig, params.TestChainConfig, chain),
+		handler: handler,
+	}
+}
+
+// close tears down the transaction pool and chain behind the mock backend.
+func (b *testBackend) close() {
+	b.txpool.Stop()
+	b.chain.Stop()
+	b.handler.Stop()
+}
+
+func (b *testBackend) Chain() *core.BlockChain { return b.chain }
+
+func (b *testBackend) RunPeer(peer *diff.Peer, handler diff.Handler) error {
+	// Normally the backend would do peer mainentance and handshakes. All that
+	// is omitted and we will just give control back to the handler.
+	return handler(peer)
+}
+func (b *testBackend) PeerInfo(enode.ID) interface{} { panic("not implemented") }
+
+func (b *testBackend) Handle(*diff.Peer, diff.Packet) error {
+	panic("data processing tests should be done in the handler package")
+}
+
+type testPeer struct {
+	*diff.Peer
+
+	net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
+	app *p2p.MsgPipeRW    // Application layer reader/writer to simulate the local side
+}
+
+// newTestPeer creates a new peer registered at the given data backend.
+func newTestPeer(name string, version uint, backend *testBackend) (*testPeer, <-chan error) {
+	// Create a message pipe to communicate through
+	app, net := p2p.MsgPipe()
+
+	// Start the peer on a new thread
+	var id enode.ID
+	rand.Read(id[:])
+
+	peer := diff.NewPeer(version, p2p.NewPeer(id, name, nil), net)
+	errc := make(chan error, 1)
+	go func() {
+		errc <- backend.RunPeer(peer, func(peer *diff.Peer) error {
+
+			return diff.Handle((*diffHandler)(backend.handler), peer)
+		})
+	}()
+	return &testPeer{app: app, net: net, Peer: peer}, errc
+}
+
+// close terminates the local side of the peer, notifying the remote protocol
+// manager of termination.
+func (p *testPeer) close() {
+	p.Peer.Close()
+	p.app.Close()
+}
+
+func TestHandleDiffLayer(t *testing.T) {
+	t.Parallel()
+
+	blockNum := 1024
+	waitInterval := 100 * time.Millisecond
+	backend := newTestBackend(blockNum)
+	defer backend.close()
+
+	peer, _ := newTestPeer("peer", diff.Diff1, backend)
+	defer peer.close()
+
+	tests := []struct {
+		DiffLayer *types.DiffLayer
+		Valid     bool
+	}{
+		{DiffLayer: &types.DiffLayer{
+			BlockHash: common.Hash{0x1},
+			Number:    1025,
+		}, Valid: true},
+		{DiffLayer: &types.DiffLayer{
+			BlockHash: common.Hash{0x2},
+			Number:    3073,
+		}, Valid: false},
+		{DiffLayer: &types.DiffLayer{
+			BlockHash: common.Hash{0x3},
+			Number:    500,
+		}, Valid: false},
+	}
+
+	for _, tt := range tests {
+		bz, _ := rlp.EncodeToBytes(tt.DiffLayer)
+
+		p2p.Send(peer.app, diff.DiffLayerMsg, diff.DiffLayersPacket{rlp.RawValue(bz)})
+	}
+	time.Sleep(waitInterval)
+	for idx, tt := range tests {
+		diff := backend.chain.GetUnTrustedDiffLayer(tt.DiffLayer.BlockHash, "")
+		if (tt.Valid && diff == nil) || (!tt.Valid && diff != nil) {
+			t.Errorf("test: %d, diff layer handle failed", idx)
+		}
+	}
+}

+ 11 - 2
eth/handler_eth.go

@@ -26,6 +26,7 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/eth/fetcher"
 	"github.com/ethereum/go-ethereum/eth/protocols/eth"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p/enode"
@@ -98,7 +99,6 @@ func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
 
 	case *eth.PooledTransactionsPacket:
 		return h.txFetcher.Enqueue(peer.ID(), *packet, true)
-
 	default:
 		return fmt.Errorf("unexpected eth packet type: %T", packet)
 	}
@@ -191,8 +191,17 @@ func (h *ethHandler) handleBlockAnnounces(peer *eth.Peer, hashes []common.Hash,
 			unknownNumbers = append(unknownNumbers, numbers[i])
 		}
 	}
+	// self support diff sync
+	var diffFetcher fetcher.DiffRequesterFn
+	if h.diffSync {
+		// the peer support diff protocol
+		if ep := h.peers.peer(peer.ID()); ep != nil && ep.diffExt != nil {
+			diffFetcher = ep.diffExt.RequestDiffLayers
+		}
+	}
+
 	for i := 0; i < len(unknownHashes); i++ {
-		h.blockFetcher.Notify(peer.ID(), unknownHashes[i], unknownNumbers[i], time.Now(), peer.RequestOneHeader, peer.RequestBodies)
+		h.blockFetcher.Notify(peer.ID(), unknownHashes[i], unknownNumbers[i], time.Now(), peer.RequestOneHeader, peer.RequestBodies, diffFetcher)
 	}
 	return nil
 }

+ 22 - 0
eth/peer.go

@@ -21,6 +21,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/ethereum/go-ethereum/eth/protocols/diff"
 	"github.com/ethereum/go-ethereum/eth/protocols/eth"
 	"github.com/ethereum/go-ethereum/eth/protocols/snap"
 )
@@ -37,6 +38,7 @@ type ethPeerInfo struct {
 type ethPeer struct {
 	*eth.Peer
 	snapExt *snapPeer // Satellite `snap` connection
+	diffExt *diffPeer
 
 	syncDrop *time.Timer   // Connection dropper if `eth` sync progress isn't validated in time
 	snapWait chan struct{} // Notification channel for snap connections
@@ -60,11 +62,31 @@ type snapPeerInfo struct {
 	Version uint `json:"version"` // Snapshot protocol version negotiated
 }
 
+// diffPeerInfo represents a short summary of the `diff` sub-protocol metadata known
+// about a connected peer.
+type diffPeerInfo struct {
+	Version  uint `json:"version"` // diff protocol version negotiated
+	DiffSync bool `json:"diff_sync"`
+}
+
 // snapPeer is a wrapper around snap.Peer to maintain a few extra metadata.
 type snapPeer struct {
 	*snap.Peer
 }
 
+// diffPeer is a wrapper around diff.Peer to maintain a few extra metadata.
+type diffPeer struct {
+	*diff.Peer
+}
+
+// info gathers and returns some `diff` protocol metadata known about a peer.
+func (p *diffPeer) info() *diffPeerInfo {
+	return &diffPeerInfo{
+		Version:  p.Version(),
+		DiffSync: p.DiffSync(),
+	}
+}
+
 // info gathers and returns some `snap` protocol metadata known about a peer.
 func (p *snapPeer) info() *snapPeerInfo {
 	return &snapPeerInfo{

+ 86 - 1
eth/peerset.go

@@ -22,6 +22,8 @@ import (
 	"sync"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/eth/downloader"
+	"github.com/ethereum/go-ethereum/eth/protocols/diff"
 	"github.com/ethereum/go-ethereum/eth/protocols/eth"
 	"github.com/ethereum/go-ethereum/eth/protocols/snap"
 	"github.com/ethereum/go-ethereum/p2p"
@@ -43,6 +45,10 @@ var (
 	// errSnapWithoutEth is returned if a peer attempts to connect only on the
 	// snap protocol without advertizing the eth main protocol.
 	errSnapWithoutEth = errors.New("peer connected on snap without compatible eth support")
+
+	// errDiffWithoutEth is returned if a peer attempts to connect only on the
+	// diff protocol without advertizing the eth main protocol.
+	errDiffWithoutEth = errors.New("peer connected on diff without compatible eth support")
 )
 
 // peerSet represents the collection of active peers currently participating in
@@ -54,6 +60,9 @@ type peerSet struct {
 	snapWait map[string]chan *snap.Peer // Peers connected on `eth` waiting for their snap extension
 	snapPend map[string]*snap.Peer      // Peers connected on the `snap` protocol, but not yet on `eth`
 
+	diffWait map[string]chan *diff.Peer // Peers connected on `eth` waiting for their diff extension
+	diffPend map[string]*diff.Peer      // Peers connected on the `diff` protocol, but not yet on `eth`
+
 	lock   sync.RWMutex
 	closed bool
 }
@@ -64,6 +73,8 @@ func newPeerSet() *peerSet {
 		peers:    make(map[string]*ethPeer),
 		snapWait: make(map[string]chan *snap.Peer),
 		snapPend: make(map[string]*snap.Peer),
+		diffWait: make(map[string]chan *diff.Peer),
+		diffPend: make(map[string]*diff.Peer),
 	}
 }
 
@@ -97,6 +108,36 @@ func (ps *peerSet) registerSnapExtension(peer *snap.Peer) error {
 	return nil
 }
 
+// registerDiffExtension unblocks an already connected `eth` peer waiting for its
+// `diff` extension, or if no such peer exists, tracks the extension for the time
+// being until the `eth` main protocol starts looking for it.
+func (ps *peerSet) registerDiffExtension(peer *diff.Peer) error {
+	// Reject the peer if it advertises `diff` without `eth` as `diff` is only a
+	// satellite protocol meaningful with the chain selection of `eth`
+	if !peer.RunningCap(eth.ProtocolName, eth.ProtocolVersions) {
+		return errDiffWithoutEth
+	}
+	// Ensure nobody can double connect
+	ps.lock.Lock()
+	defer ps.lock.Unlock()
+
+	id := peer.ID()
+	if _, ok := ps.peers[id]; ok {
+		return errPeerAlreadyRegistered // avoid connections with the same id as existing ones
+	}
+	if _, ok := ps.diffPend[id]; ok {
+		return errPeerAlreadyRegistered // avoid connections with the same id as pending ones
+	}
+	// Inject the peer into an `eth` counterpart is available, otherwise save for later
+	if wait, ok := ps.diffWait[id]; ok {
+		delete(ps.diffWait, id)
+		wait <- peer
+		return nil
+	}
+	ps.diffPend[id] = peer
+	return nil
+}
+
 // waitExtensions blocks until all satellite protocols are connected and tracked
 // by the peerset.
 func (ps *peerSet) waitSnapExtension(peer *eth.Peer) (*snap.Peer, error) {
@@ -131,9 +172,50 @@ func (ps *peerSet) waitSnapExtension(peer *eth.Peer) (*snap.Peer, error) {
 	return <-wait, nil
 }
 
+// waitDiffExtension blocks until all satellite protocols are connected and tracked
+// by the peerset.
+func (ps *peerSet) waitDiffExtension(peer *eth.Peer) (*diff.Peer, error) {
+	// If the peer does not support a compatible `diff`, don't wait
+	if !peer.RunningCap(diff.ProtocolName, diff.ProtocolVersions) {
+		return nil, nil
+	}
+	// Ensure nobody can double connect
+	ps.lock.Lock()
+
+	id := peer.ID()
+	if _, ok := ps.peers[id]; ok {
+		ps.lock.Unlock()
+		return nil, errPeerAlreadyRegistered // avoid connections with the same id as existing ones
+	}
+	if _, ok := ps.diffWait[id]; ok {
+		ps.lock.Unlock()
+		return nil, errPeerAlreadyRegistered // avoid connections with the same id as pending ones
+	}
+	// If `diff` already connected, retrieve the peer from the pending set
+	if diff, ok := ps.diffPend[id]; ok {
+		delete(ps.diffPend, id)
+
+		ps.lock.Unlock()
+		return diff, nil
+	}
+	// Otherwise wait for `diff` to connect concurrently
+	wait := make(chan *diff.Peer)
+	ps.diffWait[id] = wait
+	ps.lock.Unlock()
+
+	return <-wait, nil
+}
+
+func (ps *peerSet) GetDiffPeer(pid string) downloader.IDiffPeer {
+	if p := ps.peer(pid); p != nil && p.diffExt != nil {
+		return p.diffExt
+	}
+	return nil
+}
+
 // registerPeer injects a new `eth` peer into the working set, or returns an error
 // if the peer is already known.
-func (ps *peerSet) registerPeer(peer *eth.Peer, ext *snap.Peer) error {
+func (ps *peerSet) registerPeer(peer *eth.Peer, ext *snap.Peer, diffExt *diff.Peer) error {
 	// Start tracking the new peer
 	ps.lock.Lock()
 	defer ps.lock.Unlock()
@@ -152,6 +234,9 @@ func (ps *peerSet) registerPeer(peer *eth.Peer, ext *snap.Peer) error {
 		eth.snapExt = &snapPeer{ext}
 		ps.snapPeers++
 	}
+	if diffExt != nil {
+		eth.diffExt = &diffPeer{diffExt}
+	}
 	ps.peers[id] = eth
 	return nil
 }

+ 32 - 0
eth/protocols/diff/discovery.go

@@ -0,0 +1,32 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package diff
+
+import (
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// enrEntry is the ENR entry which advertises `diff` protocol on the discovery.
+type enrEntry struct {
+	// Ignore additional fields (for forward compatibility).
+	Rest []rlp.RawValue `rlp:"tail"`
+}
+
+// ENRKey implements enr.Entry.
+func (e enrEntry) ENRKey() string {
+	return "diff"
+}

+ 180 - 0
eth/protocols/diff/handler.go

@@ -0,0 +1,180 @@
+package diff
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/metrics"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+	// softResponseLimit is the target maximum size of replies to data retrievals.
+	softResponseLimit = 2 * 1024 * 1024
+
+	// maxDiffLayerServe is the maximum number of diff layers to serve.
+	maxDiffLayerServe = 1024
+)
+
+var requestTracker = NewTracker(time.Minute)
+
+// Handler is a callback to invoke from an outside runner after the boilerplate
+// exchanges have passed.
+type Handler func(peer *Peer) error
+
+type Backend interface {
+	// Chain retrieves the blockchain object to serve data.
+	Chain() *core.BlockChain
+
+	// RunPeer is invoked when a peer joins on the `eth` protocol. The handler
+	// should do any peer maintenance work, handshakes and validations. If all
+	// is passed, control should be given back to the `handler` to process the
+	// inbound messages going forward.
+	RunPeer(peer *Peer, handler Handler) error
+
+	PeerInfo(id enode.ID) interface{}
+
+	Handle(peer *Peer, packet Packet) error
+}
+
+// MakeProtocols constructs the P2P protocol definitions for `diff`.
+func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol {
+	// Filter the discovery iterator for nodes advertising diff support.
+	dnsdisc = enode.Filter(dnsdisc, func(n *enode.Node) bool {
+		var diff enrEntry
+		return n.Load(&diff) == nil
+	})
+
+	protocols := make([]p2p.Protocol, len(ProtocolVersions))
+	for i, version := range ProtocolVersions {
+		version := version // Closure
+
+		protocols[i] = p2p.Protocol{
+			Name:    ProtocolName,
+			Version: version,
+			Length:  protocolLengths[version],
+			Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+				return backend.RunPeer(NewPeer(version, p, rw), func(peer *Peer) error {
+					defer peer.Close()
+					return Handle(backend, peer)
+				})
+			},
+			NodeInfo: func() interface{} {
+				return nodeInfo(backend.Chain())
+			},
+			PeerInfo: func(id enode.ID) interface{} {
+				return backend.PeerInfo(id)
+			},
+			Attributes:     []enr.Entry{&enrEntry{}},
+			DialCandidates: dnsdisc,
+		}
+	}
+	return protocols
+}
+
+// Handle is the callback invoked to manage the life cycle of a `diff` peer.
+// When this function terminates, the peer is disconnected.
+func Handle(backend Backend, peer *Peer) error {
+	for {
+		if err := handleMessage(backend, peer); err != nil {
+			peer.Log().Debug("Message handling failed in `diff`", "err", err)
+			return err
+		}
+	}
+}
+
+// handleMessage is invoked whenever an inbound message is received from a
+// remote peer on the `diff` protocol. The remote connection is torn down upon
+// returning any error.
+func handleMessage(backend Backend, peer *Peer) error {
+	// Read the next message from the remote peer, and ensure it's fully consumed
+	msg, err := peer.rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	if msg.Size > maxMessageSize {
+		return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+	}
+	defer msg.Discard()
+	start := time.Now()
+	// Track the emount of time it takes to serve the request and run the handler
+	if metrics.Enabled {
+		h := fmt.Sprintf("%s/%s/%d/%#02x", p2p.HandleHistName, ProtocolName, peer.Version(), msg.Code)
+		defer func(start time.Time) {
+			sampler := func() metrics.Sample {
+				return metrics.ResettingSample(
+					metrics.NewExpDecaySample(1028, 0.015),
+				)
+			}
+			metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(start).Microseconds())
+		}(start)
+	}
+	// Handle the message depending on its contents
+	switch {
+	case msg.Code == GetDiffLayerMsg:
+		res := new(GetDiffLayersPacket)
+		if err := msg.Decode(res); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		diffs := answerDiffLayersQuery(backend, res)
+
+		p2p.Send(peer.rw, FullDiffLayerMsg, &FullDiffLayersPacket{
+			RequestId:        res.RequestId,
+			DiffLayersPacket: diffs,
+		})
+		return nil
+
+	case msg.Code == DiffLayerMsg:
+		// A batch of trie nodes arrived to one of our previous requests
+		res := new(DiffLayersPacket)
+		if err := msg.Decode(res); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		return backend.Handle(peer, res)
+	case msg.Code == FullDiffLayerMsg:
+		// A batch of trie nodes arrived to one of our previous requests
+		res := new(FullDiffLayersPacket)
+		if err := msg.Decode(res); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		if fulfilled := requestTracker.Fulfil(peer.id, peer.version, FullDiffLayerMsg, res.RequestId); fulfilled {
+			return backend.Handle(peer, res)
+		}
+		return fmt.Errorf("%w: %v", errUnexpectedMsg, msg.Code)
+	default:
+		return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
+	}
+}
+
+func answerDiffLayersQuery(backend Backend, query *GetDiffLayersPacket) []rlp.RawValue {
+	// Gather blocks until the fetch or network limits is reached
+	var (
+		bytes      int
+		diffLayers []rlp.RawValue
+	)
+	// Need avoid transfer huge package
+	for lookups, hash := range query.BlockHashes {
+		if bytes >= softResponseLimit || len(diffLayers) >= maxDiffLayerServe ||
+			lookups >= 2*maxDiffLayerServe {
+			break
+		}
+		if data := backend.Chain().GetDiffLayerRLP(hash); len(data) != 0 {
+			diffLayers = append(diffLayers, data)
+			bytes += len(data)
+		}
+	}
+	return diffLayers
+}
+
+// NodeInfo represents a short summary of the `diff` sub-protocol metadata
+// known about the host peer.
+type NodeInfo struct{}
+
+// nodeInfo retrieves some `diff` protocol metadata about the running host node.
+func nodeInfo(_ *core.BlockChain) *NodeInfo {
+	return &NodeInfo{}
+}

+ 192 - 0
eth/protocols/diff/handler_test.go

@@ -0,0 +1,192 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package diff
+
+import (
+	"math/big"
+	"math/rand"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/core/vm"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+var (
+	// testKey is a private key to use for funding a tester account.
+	testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+
+	// testAddr is the Ethereum address of the tester account.
+	testAddr = crypto.PubkeyToAddress(testKey.PublicKey)
+)
+
+// testBackend is a mock implementation of the live Ethereum message handler. Its
+// purpose is to allow testing the request/reply workflows and wire serialization
+// in the `eth` protocol without actually doing any data processing.
+type testBackend struct {
+	db     ethdb.Database
+	chain  *core.BlockChain
+	txpool *core.TxPool
+}
+
+// newTestBackend creates an empty chain and wraps it into a mock backend.
+func newTestBackend(blocks int) *testBackend {
+	return newTestBackendWithGenerator(blocks)
+}
+
+// newTestBackend creates a chain with a number of explicitly defined blocks and
+// wraps it into a mock backend.
+func newTestBackendWithGenerator(blocks int) *testBackend {
+	signer := types.HomesteadSigner{}
+	// Create a database pre-initialize with a genesis block
+	db := rawdb.NewMemoryDatabase()
+	(&core.Genesis{
+		Config: params.TestChainConfig,
+		Alloc:  core.GenesisAlloc{testAddr: {Balance: big.NewInt(100000000000000000)}},
+	}).MustCommit(db)
+
+	chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
+	generator := func(i int, block *core.BlockGen) {
+		// The chain maker doesn't have access to a chain, so the difficulty will be
+		// lets unset (nil). Set it here to the correct value.
+		block.SetCoinbase(testAddr)
+
+		// We want to simulate an empty middle block, having the same state as the
+		// first one. The last is needs a state change again to force a reorg.
+		tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), common.Address{0x01}, big.NewInt(1), params.TxGas, big.NewInt(1), nil), signer, testKey)
+		if err != nil {
+			panic(err)
+		}
+		block.AddTxWithChain(chain, tx)
+	}
+	bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, generator)
+	if _, err := chain.InsertChain(bs); err != nil {
+		panic(err)
+	}
+	txconfig := core.DefaultTxPoolConfig
+	txconfig.Journal = "" // Don't litter the disk with test journals
+
+	return &testBackend{
+		db:     db,
+		chain:  chain,
+		txpool: core.NewTxPool(txconfig, params.TestChainConfig, chain),
+	}
+}
+
+// close tears down the transaction pool and chain behind the mock backend.
+func (b *testBackend) close() {
+	b.txpool.Stop()
+	b.chain.Stop()
+}
+
+func (b *testBackend) Chain() *core.BlockChain { return b.chain }
+
+func (b *testBackend) RunPeer(peer *Peer, handler Handler) error {
+	// Normally the backend would do peer mainentance and handshakes. All that
+	// is omitted and we will just give control back to the handler.
+	return handler(peer)
+}
+func (b *testBackend) PeerInfo(enode.ID) interface{} { panic("not implemented") }
+
+func (b *testBackend) Handle(*Peer, Packet) error {
+	panic("data processing tests should be done in the handler package")
+}
+
+func TestGetDiffLayers(t *testing.T) { testGetDiffLayers(t, Diff1) }
+
+func testGetDiffLayers(t *testing.T, protocol uint) {
+	t.Parallel()
+
+	blockNum := 2048
+	backend := newTestBackend(blockNum)
+	defer backend.close()
+
+	peer, _ := newTestPeer("peer", protocol, backend)
+	defer peer.close()
+
+	foundDiffBlockHashes := make([]common.Hash, 0)
+	foundDiffPackets := make([]FullDiffLayersPacket, 0)
+	foundDiffRlps := make([]rlp.RawValue, 0)
+	missDiffBlockHashes := make([]common.Hash, 0)
+	missDiffPackets := make([]FullDiffLayersPacket, 0)
+
+	for i := 0; i < 100; i++ {
+		number := uint64(rand.Int63n(1024))
+		if number == 0 {
+			continue
+		}
+		foundHash := backend.chain.GetCanonicalHash(number + 1024)
+		missHash := backend.chain.GetCanonicalHash(number)
+		foundRlp := backend.chain.GetDiffLayerRLP(foundHash)
+
+		if len(foundHash) == 0 {
+			t.Fatalf("Faild to fond rlp encoded diff layer %v", foundHash)
+		}
+		foundDiffPackets = append(foundDiffPackets, FullDiffLayersPacket{
+			RequestId:        uint64(i),
+			DiffLayersPacket: []rlp.RawValue{foundRlp},
+		})
+		foundDiffRlps = append(foundDiffRlps, foundRlp)
+
+		missDiffPackets = append(missDiffPackets, FullDiffLayersPacket{
+			RequestId:        uint64(i),
+			DiffLayersPacket: []rlp.RawValue{},
+		})
+
+		missDiffBlockHashes = append(missDiffBlockHashes, missHash)
+		foundDiffBlockHashes = append(foundDiffBlockHashes, foundHash)
+	}
+
+	for idx, blockHash := range foundDiffBlockHashes {
+		p2p.Send(peer.app, GetDiffLayerMsg, GetDiffLayersPacket{RequestId: uint64(idx), BlockHashes: []common.Hash{blockHash}})
+		if err := p2p.ExpectMsg(peer.app, FullDiffLayerMsg, foundDiffPackets[idx]); err != nil {
+			t.Errorf("test %d: diff layer mismatch: %v", idx, err)
+		}
+	}
+
+	for idx, blockHash := range missDiffBlockHashes {
+		p2p.Send(peer.app, GetDiffLayerMsg, GetDiffLayersPacket{RequestId: uint64(idx), BlockHashes: []common.Hash{blockHash}})
+		if err := p2p.ExpectMsg(peer.app, FullDiffLayerMsg, missDiffPackets[idx]); err != nil {
+			t.Errorf("test %d: diff layer mismatch: %v", idx, err)
+		}
+	}
+
+	p2p.Send(peer.app, GetDiffLayerMsg, GetDiffLayersPacket{RequestId: 111, BlockHashes: foundDiffBlockHashes})
+	if err := p2p.ExpectMsg(peer.app, FullDiffLayerMsg, FullDiffLayersPacket{
+		111,
+		foundDiffRlps,
+	}); err != nil {
+		t.Errorf("test: diff layer mismatch: %v", err)
+	}
+
+	p2p.Send(peer.app, GetDiffLayerMsg, GetDiffLayersPacket{RequestId: 111, BlockHashes: missDiffBlockHashes})
+	if err := p2p.ExpectMsg(peer.app, FullDiffLayerMsg, FullDiffLayersPacket{
+		111,
+		nil,
+	}); err != nil {
+		t.Errorf("test: diff layer mismatch: %v", err)
+	}
+}

+ 82 - 0
eth/protocols/diff/handshake.go

@@ -0,0 +1,82 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package diff
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common/gopool"
+	"github.com/ethereum/go-ethereum/p2p"
+)
+
+const (
+	// handshakeTimeout is the maximum allowed time for the `diff` handshake to
+	// complete before dropping the connection.= as malicious.
+	handshakeTimeout = 5 * time.Second
+)
+
+// Handshake executes the diff protocol handshake,
+func (p *Peer) Handshake(diffSync bool) error {
+	// Send out own handshake in a new thread
+	errc := make(chan error, 2)
+
+	var cap DiffCapPacket // safe to read after two values have been received from errc
+
+	gopool.Submit(func() {
+		errc <- p2p.Send(p.rw, DiffCapMsg, &DiffCapPacket{
+			DiffSync: diffSync,
+			Extra:    defaultExtra,
+		})
+	})
+	gopool.Submit(func() {
+		errc <- p.readCap(&cap)
+	})
+	timeout := time.NewTimer(handshakeTimeout)
+	defer timeout.Stop()
+	for i := 0; i < 2; i++ {
+		select {
+		case err := <-errc:
+			if err != nil {
+				return err
+			}
+		case <-timeout.C:
+			return p2p.DiscReadTimeout
+		}
+	}
+	p.diffSync = cap.DiffSync
+	return nil
+}
+
+// readStatus reads the remote handshake message.
+func (p *Peer) readCap(cap *DiffCapPacket) error {
+	msg, err := p.rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	if msg.Code != DiffCapMsg {
+		return fmt.Errorf("%w: first msg has code %x (!= %x)", errNoCapMsg, msg.Code, DiffCapMsg)
+	}
+	if msg.Size > maxMessageSize {
+		return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+	}
+	// Decode the handshake and make sure everything matches
+	if err := msg.Decode(cap); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	return nil
+}

+ 107 - 0
eth/protocols/diff/peer.go

@@ -0,0 +1,107 @@
+package diff
+
+import (
+	"math/rand"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+const maxQueuedDiffLayers = 12
+
+// Peer is a collection of relevant information we have about a `diff` peer.
+type Peer struct {
+	id               string              // Unique ID for the peer, cached
+	diffSync         bool                // whether the peer can diff sync
+	queuedDiffLayers chan []rlp.RawValue // Queue of diff layers to broadcast to the peer
+
+	*p2p.Peer                   // The embedded P2P package peer
+	rw        p2p.MsgReadWriter // Input/output streams for diff
+	version   uint              // Protocol version negotiated
+	logger    log.Logger        // Contextual logger with the peer id injected
+	term      chan struct{}     // Termination channel to stop the broadcasters
+}
+
+// NewPeer create a wrapper for a network connection and negotiated  protocol
+// version.
+func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer {
+	id := p.ID().String()
+	peer := &Peer{
+		id:               id,
+		Peer:             p,
+		rw:               rw,
+		diffSync:         false,
+		version:          version,
+		logger:           log.New("peer", id[:8]),
+		queuedDiffLayers: make(chan []rlp.RawValue, maxQueuedDiffLayers),
+		term:             make(chan struct{}),
+	}
+	go peer.broadcastDiffLayers()
+	return peer
+}
+
+func (p *Peer) broadcastDiffLayers() {
+	for {
+		select {
+		case prop := <-p.queuedDiffLayers:
+			if err := p.SendDiffLayers(prop); err != nil {
+				p.Log().Error("Failed to propagated diff layer", "err", err)
+				return
+			}
+		case <-p.term:
+			return
+		}
+	}
+}
+
+// ID retrieves the peer's unique identifier.
+func (p *Peer) ID() string {
+	return p.id
+}
+
+// Version retrieves the peer's negoatiated `diff` protocol version.
+func (p *Peer) Version() uint {
+	return p.version
+}
+
+func (p *Peer) DiffSync() bool {
+	return p.diffSync
+}
+
+// Log overrides the P2P logget with the higher level one containing only the id.
+func (p *Peer) Log() log.Logger {
+	return p.logger
+}
+
+// Close signals the broadcast goroutine to terminate. Only ever call this if
+// you created the peer yourself via NewPeer. Otherwise let whoever created it
+// clean it up!
+func (p *Peer) Close() {
+	close(p.term)
+}
+
+// RequestDiffLayers fetches a batch of diff layers corresponding to the hashes
+// specified.
+func (p *Peer) RequestDiffLayers(hashes []common.Hash) error {
+	id := rand.Uint64()
+
+	requestTracker.Track(p.id, p.version, GetDiffLayerMsg, FullDiffLayerMsg, id)
+	return p2p.Send(p.rw, GetDiffLayerMsg, GetDiffLayersPacket{
+		RequestId:   id,
+		BlockHashes: hashes,
+	})
+}
+
+func (p *Peer) SendDiffLayers(diffs []rlp.RawValue) error {
+	return p2p.Send(p.rw, DiffLayerMsg, diffs)
+}
+
+func (p *Peer) AsyncSendDiffLayer(diffLayers []rlp.RawValue) {
+	select {
+	case p.queuedDiffLayers <- diffLayers:
+	default:
+		p.Log().Debug("Dropping diff layers propagation")
+	}
+}

+ 61 - 0
eth/protocols/diff/peer_test.go

@@ -0,0 +1,61 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// This file contains some shares testing functionality, common to  multiple
+// different files and modules being tested.
+
+package diff
+
+import (
+	"crypto/rand"
+
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// testPeer is a simulated peer to allow testing direct network calls.
+type testPeer struct {
+	*Peer
+
+	net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
+	app *p2p.MsgPipeRW    // Application layer reader/writer to simulate the local side
+}
+
+// newTestPeer creates a new peer registered at the given data backend.
+func newTestPeer(name string, version uint, backend Backend) (*testPeer, <-chan error) {
+	// Create a message pipe to communicate through
+	app, net := p2p.MsgPipe()
+
+	// Start the peer on a new thread
+	var id enode.ID
+	rand.Read(id[:])
+
+	peer := NewPeer(version, p2p.NewPeer(id, name, nil), net)
+	errc := make(chan error, 1)
+	go func() {
+		errc <- backend.RunPeer(peer, func(peer *Peer) error {
+			return Handle(backend, peer)
+		})
+	}()
+	return &testPeer{app: app, net: net, Peer: peer}, errc
+}
+
+// close terminates the local side of the peer, notifying the remote protocol
+// manager of termination.
+func (p *testPeer) close() {
+	p.Peer.Close()
+	p.app.Close()
+}

+ 122 - 0
eth/protocols/diff/protocol.go

@@ -0,0 +1,122 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package diff
+
+import (
+	"errors"
+	"fmt"
+
+	"golang.org/x/crypto/sha3"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// Constants to match up protocol versions and messages
+const (
+	Diff1 = 1
+)
+
+// ProtocolName is the official short name of the `diff` protocol used during
+// devp2p capability negotiation.
+const ProtocolName = "diff"
+
+// ProtocolVersions are the supported versions of the `diff` protocol (first
+// is primary).
+var ProtocolVersions = []uint{Diff1}
+
+// protocolLengths are the number of implemented message corresponding to
+// different protocol versions.
+var protocolLengths = map[uint]uint64{Diff1: 4}
+
+// maxMessageSize is the maximum cap on the size of a protocol message.
+const maxMessageSize = 10 * 1024 * 1024
+
+const (
+	DiffCapMsg       = 0x00
+	GetDiffLayerMsg  = 0x01
+	DiffLayerMsg     = 0x02
+	FullDiffLayerMsg = 0x03
+)
+
+var defaultExtra = []byte{0x00}
+
+var (
+	errMsgTooLarge    = errors.New("message too long")
+	errDecode         = errors.New("invalid message")
+	errInvalidMsgCode = errors.New("invalid message code")
+	errUnexpectedMsg  = errors.New("unexpected message code")
+	errNoCapMsg       = errors.New("miss cap message during handshake")
+)
+
+// Packet represents a p2p message in the `diff` protocol.
+type Packet interface {
+	Name() string // Name returns a string corresponding to the message type.
+	Kind() byte   // Kind returns the message type.
+}
+
+type GetDiffLayersPacket struct {
+	RequestId   uint64
+	BlockHashes []common.Hash
+}
+
+func (p *DiffLayersPacket) Unpack() ([]*types.DiffLayer, error) {
+	diffLayers := make([]*types.DiffLayer, 0, len(*p))
+	hasher := sha3.NewLegacyKeccak256()
+	for _, rawData := range *p {
+		var diff types.DiffLayer
+		err := rlp.DecodeBytes(rawData, &diff)
+		if err != nil {
+			return nil, fmt.Errorf("%w: diff layer %v", errDecode, err)
+		}
+		diffLayers = append(diffLayers, &diff)
+		_, err = hasher.Write(rawData)
+		if err != nil {
+			return nil, err
+		}
+		var diffHash common.Hash
+		hasher.Sum(diffHash[:0])
+		hasher.Reset()
+		diff.DiffHash = diffHash
+	}
+	return diffLayers, nil
+}
+
+type DiffCapPacket struct {
+	DiffSync bool
+	Extra    rlp.RawValue // for extension
+}
+
+type DiffLayersPacket []rlp.RawValue
+
+type FullDiffLayersPacket struct {
+	RequestId uint64
+	DiffLayersPacket
+}
+
+func (*GetDiffLayersPacket) Name() string { return "GetDiffLayers" }
+func (*GetDiffLayersPacket) Kind() byte   { return GetDiffLayerMsg }
+
+func (*DiffLayersPacket) Name() string { return "DiffLayers" }
+func (*DiffLayersPacket) Kind() byte   { return DiffLayerMsg }
+
+func (*FullDiffLayersPacket) Name() string { return "FullDiffLayers" }
+func (*FullDiffLayersPacket) Kind() byte   { return FullDiffLayerMsg }
+
+func (*DiffCapPacket) Name() string { return "DiffCap" }
+func (*DiffCapPacket) Kind() byte   { return DiffCapMsg }

+ 131 - 0
eth/protocols/diff/protocol_test.go

@@ -0,0 +1,131 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package diff
+
+import (
+	"bytes"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// Tests that the custom union field encoder and decoder works correctly.
+func TestDiffLayersPacketDataEncodeDecode(t *testing.T) {
+	// Create a "random" hash for testing
+	var hash common.Hash
+	for i := range hash {
+		hash[i] = byte(i)
+	}
+
+	testDiffLayers := []*types.DiffLayer{
+		{
+			BlockHash: common.HexToHash("0x1e9624dcd0874958723aa3dae1fe299861e93ef32b980143d798c428bdd7a20a"),
+			Number:    10479133,
+			Receipts: []*types.Receipt{{
+				GasUsed:          100,
+				TransactionIndex: 1,
+			}},
+			Codes: []types.DiffCode{{
+				Hash: common.HexToHash("0xaece2dbf80a726206cf4df299afa09f9d8f3dcd85ff39bb4b3f0402a3a6af2f5"),
+				Code: []byte{1, 2, 3, 4},
+			}},
+			Destructs: []common.Address{
+				common.HexToAddress("0x0205bb28ece9289d3fb8eb0c9e999bbd5be2b931"),
+			},
+			Accounts: []types.DiffAccount{{
+				Account: common.HexToAddress("0x18b2a687610328590bc8f2e5fedde3b582a49cda"),
+				Blob:    []byte{2, 3, 4, 5},
+			}},
+			Storages: []types.DiffStorage{{
+				Account: common.HexToAddress("0x18b2a687610328590bc8f2e5fedde3b582a49cda"),
+				Keys:    []string{"abc"},
+				Vals:    [][]byte{{1, 2, 3}},
+			}},
+		},
+	}
+	// Assemble some table driven tests
+	tests := []struct {
+		diffLayers []*types.DiffLayer
+		fail       bool
+	}{
+		{fail: false, diffLayers: testDiffLayers},
+	}
+	// Iterate over each of the tests and try to encode and then decode
+	for i, tt := range tests {
+		originPacket := make([]rlp.RawValue, 0)
+		for _, diff := range tt.diffLayers {
+			bz, err := rlp.EncodeToBytes(diff)
+			assert.NoError(t, err)
+			originPacket = append(originPacket, bz)
+		}
+
+		bz, err := rlp.EncodeToBytes(DiffLayersPacket(originPacket))
+		if err != nil && !tt.fail {
+			t.Fatalf("test %d: failed to encode packet: %v", i, err)
+		} else if err == nil && tt.fail {
+			t.Fatalf("test %d: encode should have failed", i)
+		}
+		if !tt.fail {
+			packet := new(DiffLayersPacket)
+			if err := rlp.DecodeBytes(bz, packet); err != nil {
+				t.Fatalf("test %d: failed to decode packet: %v", i, err)
+			}
+			diffLayers, err := packet.Unpack()
+			assert.NoError(t, err)
+
+			if len(diffLayers) != len(tt.diffLayers) {
+				t.Fatalf("test %d: encode length mismatch: have %+v, want %+v", i, len(diffLayers), len(tt.diffLayers))
+			}
+			expectedPacket := make([]rlp.RawValue, 0)
+			for _, diff := range diffLayers {
+				bz, err := rlp.EncodeToBytes(diff)
+				assert.NoError(t, err)
+				expectedPacket = append(expectedPacket, bz)
+			}
+			for i := 0; i < len(expectedPacket); i++ {
+				if !bytes.Equal(expectedPacket[i], originPacket[i]) {
+					t.Fatalf("test %d: data change during encode and decode", i)
+				}
+			}
+		}
+	}
+}
+
+func TestDiffMessages(t *testing.T) {
+
+	for i, tc := range []struct {
+		message interface{}
+		want    []byte
+	}{
+		{
+			DiffCapPacket{true, defaultExtra},
+			common.FromHex("c20100"),
+		},
+		{
+			GetDiffLayersPacket{1111, []common.Hash{common.HexToHash("0xaece2dbf80a726206cf4df299afa09f9d8f3dcd85ff39bb4b3f0402a3a6af2f5")}},
+			common.FromHex("e5820457e1a0aece2dbf80a726206cf4df299afa09f9d8f3dcd85ff39bb4b3f0402a3a6af2f5"),
+		},
+	} {
+		if have, _ := rlp.EncodeToBytes(tc.message); !bytes.Equal(have, tc.want) {
+			t.Errorf("test %d, type %T, have\n\t%x\nwant\n\t%x", i, tc.message, have, tc.want)
+		}
+	}
+}

+ 161 - 0
eth/protocols/diff/tracker.go

@@ -0,0 +1,161 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package diff
+
+import (
+	"container/list"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/log"
+)
+
+const (
+	// maxTrackedPackets is a huge number to act as a failsafe on the number of
+	// pending requests the node will track. It should never be hit unless an
+	// attacker figures out a way to spin requests.
+	maxTrackedPackets = 10000
+)
+
+// request tracks sent network requests which have not yet received a response.
+type request struct {
+	peer    string
+	version uint // Protocol version
+
+	reqCode uint64 // Protocol message code of the request
+	resCode uint64 // Protocol message code of the expected response
+
+	time   time.Time     // Timestamp when the request was made
+	expire *list.Element // Expiration marker to untrack it
+}
+
+type Tracker struct {
+	timeout time.Duration // Global timeout after which to drop a tracked packet
+
+	pending map[uint64]*request // Currently pending requests
+	expire  *list.List          // Linked list tracking the expiration order
+	wake    *time.Timer         // Timer tracking the expiration of the next item
+
+	lock sync.Mutex // Lock protecting from concurrent updates
+}
+
+func NewTracker(timeout time.Duration) *Tracker {
+	return &Tracker{
+		timeout: timeout,
+		pending: make(map[uint64]*request),
+		expire:  list.New(),
+	}
+}
+
+// Track adds a network request to the tracker to wait for a response to arrive
+// or until the request it cancelled or times out.
+func (t *Tracker) Track(peer string, version uint, reqCode uint64, resCode uint64, id uint64) {
+	t.lock.Lock()
+	defer t.lock.Unlock()
+
+	// If there's a duplicate request, we've just random-collided (or more probably,
+	// we have a bug), report it. We could also add a metric, but we're not really
+	// expecting ourselves to be buggy, so a noisy warning should be enough.
+	if _, ok := t.pending[id]; ok {
+		log.Error("Network request id collision", "version", version, "code", reqCode, "id", id)
+		return
+	}
+	// If we have too many pending requests, bail out instead of leaking memory
+	if pending := len(t.pending); pending >= maxTrackedPackets {
+		log.Error("Request tracker exceeded allowance", "pending", pending, "peer", peer, "version", version, "code", reqCode)
+		return
+	}
+	// Id doesn't exist yet, start tracking it
+	t.pending[id] = &request{
+		peer:    peer,
+		version: version,
+		reqCode: reqCode,
+		resCode: resCode,
+		time:    time.Now(),
+		expire:  t.expire.PushBack(id),
+	}
+
+	// If we've just inserted the first item, start the expiration timer
+	if t.wake == nil {
+		t.wake = time.AfterFunc(t.timeout, t.clean)
+	}
+}
+
+// clean is called automatically when a preset time passes without a response
+// being dleivered for the first network request.
+func (t *Tracker) clean() {
+	t.lock.Lock()
+	defer t.lock.Unlock()
+
+	// Expire anything within a certain threshold (might be no items at all if
+	// we raced with the delivery)
+	for t.expire.Len() > 0 {
+		// Stop iterating if the next pending request is still alive
+		var (
+			head = t.expire.Front()
+			id   = head.Value.(uint64)
+			req  = t.pending[id]
+		)
+		if time.Since(req.time) < t.timeout+5*time.Millisecond {
+			break
+		}
+		// Nope, dead, drop it
+		t.expire.Remove(head)
+		delete(t.pending, id)
+	}
+	t.schedule()
+}
+
+// schedule starts a timer to trigger on the expiration of the first network
+// packet.
+func (t *Tracker) schedule() {
+	if t.expire.Len() == 0 {
+		t.wake = nil
+		return
+	}
+	t.wake = time.AfterFunc(time.Until(t.pending[t.expire.Front().Value.(uint64)].time.Add(t.timeout)), t.clean)
+}
+
+// Fulfil fills a pending request, if any is available.
+func (t *Tracker) Fulfil(peer string, version uint, code uint64, id uint64) bool {
+	t.lock.Lock()
+	defer t.lock.Unlock()
+
+	// If it's a non existing request, track as stale response
+	req, ok := t.pending[id]
+	if !ok {
+		return false
+	}
+	// If the response is funky, it might be some active attack
+	if req.peer != peer || req.version != version || req.resCode != code {
+		log.Warn("Network response id collision",
+			"have", fmt.Sprintf("%s:/%d:%d", peer, version, code),
+			"want", fmt.Sprintf("%s:/%d:%d", peer, req.version, req.resCode),
+		)
+		return false
+	}
+	// Everything matches, mark the request serviced
+	t.expire.Remove(req.expire)
+	delete(t.pending, id)
+	if req.expire.Prev() == nil {
+		if t.wake.Stop() {
+			t.schedule()
+		}
+	}
+	return true
+}

+ 2 - 2
eth/state_accessor.go

@@ -114,12 +114,12 @@ func (eth *Ethereum) stateAtBlock(block *types.Block, reexec uint64, base *state
 		if current = eth.blockchain.GetBlockByNumber(next); current == nil {
 			return nil, fmt.Errorf("block #%d not found", next)
 		}
-		_, _, _, err := eth.blockchain.Processor().Process(current, statedb, vm.Config{})
+		statedb, _, _, _, err := eth.blockchain.Processor().Process(current, statedb, vm.Config{})
 		if err != nil {
 			return nil, fmt.Errorf("processing block %d failed: %v", current.NumberU64(), err)
 		}
 		// Finalize the state so any modifications are written to the trie
-		root, err := statedb.Commit(eth.blockchain.Config().IsEIP158(current.Number()))
+		root, _, err := statedb.Commit(eth.blockchain.Config().IsEIP158(current.Number()))
 		if err != nil {
 			return nil, err
 		}

+ 1 - 1
eth/tracers/tracers_test.go

@@ -357,7 +357,7 @@ func BenchmarkTransactionTrace(b *testing.B) {
 		//DisableReturnData: true,
 	})
 	evm := vm.NewEVM(context, txContext, statedb, params.AllEthashProtocolChanges, vm.Config{Debug: true, Tracer: tracer})
-	msg, err := tx.AsMessage(signer, nil)
+	msg, err := tx.AsMessage(signer)
 	if err != nil {
 		b.Fatalf("failed to prepare transaction for tracing: %v", err)
 	}

+ 2 - 68
ethclient/ethclient_test.go

@@ -17,7 +17,6 @@
 package ethclient
 
 import (
-	"bytes"
 	"context"
 	"errors"
 	"fmt"
@@ -262,9 +261,7 @@ func TestEthClient(t *testing.T) {
 		"TestCallContract": {
 			func(t *testing.T) { testCallContract(t, client) },
 		},
-		"TestAtFunctions": {
-			func(t *testing.T) { testAtFunctions(t, client) },
-		},
+		// DO not have TestAtFunctions now, because we do not have pending block now
 	}
 
 	t.Parallel()
@@ -490,69 +487,6 @@ func testCallContract(t *testing.T, client *rpc.Client) {
 	}
 }
 
-func testAtFunctions(t *testing.T, client *rpc.Client) {
-	ec := NewClient(client)
-	// send a transaction for some interesting pending status
-	sendTransaction(ec)
-	time.Sleep(100 * time.Millisecond)
-	// Check pending transaction count
-	pending, err := ec.PendingTransactionCount(context.Background())
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	if pending != 1 {
-		t.Fatalf("unexpected pending, wanted 1 got: %v", pending)
-	}
-	// Query balance
-	balance, err := ec.BalanceAt(context.Background(), testAddr, nil)
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	penBalance, err := ec.PendingBalanceAt(context.Background(), testAddr)
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	if balance.Cmp(penBalance) == 0 {
-		t.Fatalf("unexpected balance: %v %v", balance, penBalance)
-	}
-	// NonceAt
-	nonce, err := ec.NonceAt(context.Background(), testAddr, nil)
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	penNonce, err := ec.PendingNonceAt(context.Background(), testAddr)
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	if penNonce != nonce+1 {
-		t.Fatalf("unexpected nonce: %v %v", nonce, penNonce)
-	}
-	// StorageAt
-	storage, err := ec.StorageAt(context.Background(), testAddr, common.Hash{}, nil)
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	penStorage, err := ec.PendingStorageAt(context.Background(), testAddr, common.Hash{})
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	if !bytes.Equal(storage, penStorage) {
-		t.Fatalf("unexpected storage: %v %v", storage, penStorage)
-	}
-	// CodeAt
-	code, err := ec.CodeAt(context.Background(), testAddr, nil)
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	penCode, err := ec.PendingCodeAt(context.Background(), testAddr)
-	if err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-	if !bytes.Equal(code, penCode) {
-		t.Fatalf("unexpected code: %v %v", code, penCode)
-	}
-}
-
 func sendTransaction(ec *Client) error {
 	// Retrieve chainID
 	chainID, err := ec.ChainID(context.Background())
@@ -560,7 +494,7 @@ func sendTransaction(ec *Client) error {
 		return err
 	}
 	// Create transaction
-	tx := types.NewTransaction(0, common.Address{1}, big.NewInt(1), 22000, big.NewInt(1), nil)
+	tx := types.NewTransaction(0, common.Address{1}, big.NewInt(1), 23000, big.NewInt(100000), nil)
 	signer := types.LatestSignerForChainID(chainID)
 	signature, err := crypto.Sign(signer.Hash(tx).Bytes(), testKey)
 	if err != nil {

+ 6 - 0
ethdb/database.go

@@ -118,11 +118,17 @@ type AncientStore interface {
 	io.Closer
 }
 
+type DiffStore interface {
+	DiffStore() KeyValueStore
+	SetDiffStore(diff KeyValueStore)
+}
+
 // Database contains all the methods required by the high level database to not
 // only access the key-value data store but also the chain freezer.
 type Database interface {
 	Reader
 	Writer
+	DiffStore
 	Batcher
 	Iteratee
 	Stater

+ 2 - 2
les/fetcher.go

@@ -339,7 +339,7 @@ func (f *lightFetcher) mainloop() {
 					log.Debug("Trigger light sync", "peer", peerid, "local", localHead.Number, "localhash", localHead.Hash(), "remote", data.Number, "remotehash", data.Hash)
 					continue
 				}
-				f.fetcher.Notify(peerid.String(), data.Hash, data.Number, time.Now(), f.requestHeaderByHash(peerid), nil)
+				f.fetcher.Notify(peerid.String(), data.Hash, data.Number, time.Now(), f.requestHeaderByHash(peerid), nil, nil)
 				log.Debug("Trigger header retrieval", "peer", peerid, "number", data.Number, "hash", data.Hash)
 			}
 			// Keep collecting announces from trusted server even we are syncing.
@@ -355,7 +355,7 @@ func (f *lightFetcher) mainloop() {
 						continue
 					}
 					p := agreed[rand.Intn(len(agreed))]
-					f.fetcher.Notify(p.String(), data.Hash, data.Number, time.Now(), f.requestHeaderByHash(p), nil)
+					f.fetcher.Notify(p.String(), data.Hash, data.Number, time.Now(), f.requestHeaderByHash(p), nil, nil)
 					log.Debug("Trigger trusted header retrieval", "number", data.Number, "hash", data.Hash)
 				}
 			}

+ 1 - 1
les/peer.go

@@ -1054,7 +1054,7 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
 
 			// If local ethereum node is running in archive mode, advertise ourselves we have
 			// all version state data. Otherwise only recent state is available.
-			stateRecent := uint64(server.handler.blockchain.TriesInMemory() - blockSafetyMargin)
+			stateRecent := server.handler.blockchain.TriesInMemory() - blockSafetyMargin
 			if server.archiveMode {
 				stateRecent = 0
 			}

+ 3 - 9
light/trie.go

@@ -95,17 +95,11 @@ func (db *odrDatabase) TrieDB() *trie.Database {
 	return nil
 }
 
-func (db *odrDatabase) CacheAccount(_ common.Hash, _ state.Trie) {
-	return
-}
+func (db *odrDatabase) CacheAccount(_ common.Hash, _ state.Trie) {}
 
-func (db *odrDatabase) CacheStorage(_ common.Hash, _ common.Hash, _ state.Trie) {
-	return
-}
+func (db *odrDatabase) CacheStorage(_ common.Hash, _ common.Hash, _ state.Trie) {}
 
-func (db *odrDatabase) Purge() {
-	return
-}
+func (db *odrDatabase) Purge() {}
 
 type odrTrie struct {
 	db   *odrDatabase

+ 0 - 19
miner/worker.go

@@ -1011,16 +1011,6 @@ func (w *worker) commit(uncles []*types.Header, interval func(), update bool, st
 	return nil
 }
 
-// copyReceipts makes a deep copy of the given receipts.
-func copyReceipts(receipts []*types.Receipt) []*types.Receipt {
-	result := make([]*types.Receipt, len(receipts))
-	for i, l := range receipts {
-		cpy := *l
-		result[i] = &cpy
-	}
-	return result
-}
-
 // postSideBlock fires a side chain event, only use it for testing.
 func (w *worker) postSideBlock(event core.ChainSideEvent) {
 	select {
@@ -1028,12 +1018,3 @@ func (w *worker) postSideBlock(event core.ChainSideEvent) {
 	case <-w.exitCh:
 	}
 }
-
-// totalFees computes total consumed fees in ETH. Block transactions and receipts have to have the same order.
-func totalFees(block *types.Block, receipts []*types.Receipt) *big.Float {
-	feesWei := new(big.Int)
-	for i, tx := range block.Transactions() {
-		feesWei.Add(feesWei, new(big.Int).Mul(new(big.Int).SetUint64(receipts[i].GasUsed), tx.GasPrice()))
-	}
-	return new(big.Float).Quo(new(big.Float).SetInt(feesWei), new(big.Float).SetInt(big.NewInt(params.Ether)))
-}

+ 3 - 0
node/config.go

@@ -98,6 +98,9 @@ type Config struct {
 	// DirectBroadcast enable directly broadcast mined block to all peers
 	DirectBroadcast bool `toml:",omitempty"`
 
+	// DisableSnapProtocol disable the snap protocol
+	DisableSnapProtocol bool `toml:",omitempty"`
+
 	// RangeLimit enable 5000 blocks limit when handle range query
 	RangeLimit bool `toml:",omitempty"`
 

+ 41 - 0
node/node.go

@@ -30,6 +30,7 @@ import (
 	"github.com/ethereum/go-ethereum/accounts"
 	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/ethdb/leveldb"
 	"github.com/ethereum/go-ethereum/event"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p"
@@ -578,6 +579,22 @@ func (n *Node) OpenDatabase(name string, cache, handles int, namespace string, r
 	return db, err
 }
 
+func (n *Node) OpenAndMergeDatabase(name string, cache, handles int, freezer, diff, namespace string, readonly, persistDiff bool) (ethdb.Database, error) {
+	chainDB, err := n.OpenDatabaseWithFreezer(name, cache, handles, freezer, namespace, readonly)
+	if err != nil {
+		return nil, err
+	}
+	if persistDiff {
+		diffStore, err := n.OpenDiffDatabase(name, handles, diff, namespace, readonly)
+		if err != nil {
+			chainDB.Close()
+			return nil, err
+		}
+		chainDB.SetDiffStore(diffStore)
+	}
+	return chainDB, nil
+}
+
 // OpenDatabaseWithFreezer opens an existing database with the given name (or
 // creates one if no previous can be found) from within the node's data directory,
 // also attaching a chain freezer to it that moves ancient chain data from the
@@ -611,6 +628,30 @@ func (n *Node) OpenDatabaseWithFreezer(name string, cache, handles int, freezer,
 	return db, err
 }
 
+func (n *Node) OpenDiffDatabase(name string, handles int, diff, namespace string, readonly bool) (*leveldb.Database, error) {
+	n.lock.Lock()
+	defer n.lock.Unlock()
+	if n.state == closedState {
+		return nil, ErrNodeStopped
+	}
+
+	var db *leveldb.Database
+	var err error
+	if n.config.DataDir == "" {
+		panic("datadir is missing")
+	}
+	root := n.ResolvePath(name)
+	switch {
+	case diff == "":
+		diff = filepath.Join(root, "diff")
+	case !filepath.IsAbs(diff):
+		diff = n.ResolvePath(diff)
+	}
+	db, err = leveldb.New(diff, 0, handles, namespace, readonly)
+
+	return db, err
+}
+
 // ResolvePath returns the absolute path of a resource in the instance directory.
 func (n *Node) ResolvePath(x string) string {
 	return n.config.ResolvePath(x)

+ 0 - 10
rlp/typecache.go

@@ -172,16 +172,6 @@ func structFields(typ reflect.Type) (fields []field, err error) {
 	return fields, nil
 }
 
-// anyOptionalFields returns the index of the first field with "optional" tag.
-func firstOptionalField(fields []field) int {
-	for i, f := range fields {
-		if f.optional {
-			return i
-		}
-	}
-	return len(fields)
-}
-
 type structFieldError struct {
 	typ   reflect.Type
 	field int

+ 1 - 1
tests/state_test_util.go

@@ -226,7 +226,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter boo
 		}
 	}
 	// Commit and re-open to start with a clean state.
-	root, _ := statedb.Commit(false)
+	root, _, _ := statedb.Commit(false)
 
 	var snaps *snapshot.Tree
 	if snapshotter {