
Merge pull request #14723 from Arachnid/downloadrefactor

Refactor downloader to use interfaces instead of callbacks
Nick Johnson, 8 years ago
parent commit 4b8860a7b4

eth/downloader/downloader.go  +130 -105

@@ -114,21 +114,11 @@ type Downloader struct {
 	syncStatsState       stateSyncStats
 	syncStatsLock        sync.RWMutex // Lock protecting the sync stats fields
 
+	lightchain LightChain
+	blockchain BlockChain
+
 	// Callbacks
-	hasHeader        headerCheckFn            // Checks if a header is present in the chain
-	hasBlockAndState blockAndStateCheckFn     // Checks if a block and associated state is present in the chain
-	getHeader        headerRetrievalFn        // Retrieves a header from the chain
-	getBlock         blockRetrievalFn         // Retrieves a block from the chain
-	headHeader       headHeaderRetrievalFn    // Retrieves the head header from the chain
-	headBlock        headBlockRetrievalFn     // Retrieves the head block from the chain
-	headFastBlock    headFastBlockRetrievalFn // Retrieves the head fast-sync block from the chain
-	commitHeadBlock  headBlockCommitterFn     // Commits a manually assembled block as the chain head
-	getTd            tdRetrievalFn            // Retrieves the TD of a block from the chain
-	insertHeaders    headerChainInsertFn      // Injects a batch of headers into the chain
-	insertBlocks     blockChainInsertFn       // Injects a batch of blocks into the chain
-	insertReceipts   receiptChainInsertFn     // Injects a batch of blocks and their receipts into the chain
-	rollback         chainRollbackFn          // Removes a batch of recently added chain links
-	dropPeer         peerDropFn               // Drops a peer for misbehaving
+	dropPeer peerDropFn // Drops a peer for misbehaving
 
 	// Status
 	synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
@@ -163,45 +153,80 @@ type Downloader struct {
 	chainInsertHook  func([]*fetchResult)  // Method to call upon inserting a chain of blocks (possibly in multiple invocations)
 }
 
+// LightChain encapsulates functions required to synchronise a light chain.
+type LightChain interface {
+	// HasHeader verifies a header's presence in the local chain.
+	HasHeader(common.Hash) bool
+
+	// GetHeaderByHash retrieves a header from the local chain.
+	GetHeaderByHash(common.Hash) *types.Header
+
+	// CurrentHeader retrieves the head header from the local chain.
+	CurrentHeader() *types.Header
+
+	// GetTdByHash returns the total difficulty of a local block.
+	GetTdByHash(common.Hash) *big.Int
+
+	// InsertHeaderChain inserts a batch of headers into the local chain.
+	InsertHeaderChain([]*types.Header, int) (int, error)
+
+	// Rollback removes a few recently added elements from the local chain.
+	Rollback([]common.Hash)
+}
+
+// BlockChain encapsulates functions required to sync a (full or fast) blockchain.
+type BlockChain interface {
+	LightChain
+
+	// HasBlockAndState verifies block and associated states' presence in the local chain.
+	HasBlockAndState(common.Hash) bool
+
+	// GetBlockByHash retrieves a block from the local chain.
+	GetBlockByHash(common.Hash) *types.Block
+
+	// CurrentBlock retrieves the head block from the local chain.
+	CurrentBlock() *types.Block
+
+	// CurrentFastBlock retrieves the head fast block from the local chain.
+	CurrentFastBlock() *types.Block
+
+	// FastSyncCommitHead directly commits the head block to a certain entity.
+	FastSyncCommitHead(common.Hash) error
+
+	// InsertChain inserts a batch of blocks into the local chain.
+	InsertChain(types.Blocks) (int, error)
+
+	// InsertReceiptChain inserts a batch of receipts into the local chain.
+	InsertReceiptChain(types.Blocks, []types.Receipts) (int, error)
+}
+
 // New creates a new downloader to fetch hashes and blocks from remote peers.
-func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlockAndState blockAndStateCheckFn,
-	getHeader headerRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn,
-	headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn,
-	insertBlocks blockChainInsertFn, insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader {
+func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn) *Downloader {
+	if lightchain == nil {
+		lightchain = chain
+	}
 
 	dl := &Downloader{
-		mode:             mode,
-		mux:              mux,
-		queue:            newQueue(),
-		peers:            newPeerSet(),
-		stateDB:          stateDb,
-		rttEstimate:      uint64(rttMaxEstimate),
-		rttConfidence:    uint64(1000000),
-		hasHeader:        hasHeader,
-		hasBlockAndState: hasBlockAndState,
-		getHeader:        getHeader,
-		getBlock:         getBlock,
-		headHeader:       headHeader,
-		headBlock:        headBlock,
-		headFastBlock:    headFastBlock,
-		commitHeadBlock:  commitHeadBlock,
-		getTd:            getTd,
-		insertHeaders:    insertHeaders,
-		insertBlocks:     insertBlocks,
-		insertReceipts:   insertReceipts,
-		rollback:         rollback,
-		dropPeer:         dropPeer,
-		headerCh:         make(chan dataPack, 1),
-		bodyCh:           make(chan dataPack, 1),
-		receiptCh:        make(chan dataPack, 1),
-		bodyWakeCh:       make(chan bool, 1),
-		receiptWakeCh:    make(chan bool, 1),
-		headerProcCh:     make(chan []*types.Header, 1),
-		quitCh:           make(chan struct{}),
-		// for stateFetcher
+		mode:           mode,
+		stateDB:        stateDb,
+		mux:            mux,
+		queue:          newQueue(),
+		peers:          newPeerSet(),
+		rttEstimate:    uint64(rttMaxEstimate),
+		rttConfidence:  uint64(1000000),
+		blockchain:     chain,
+		lightchain:     lightchain,
+		dropPeer:       dropPeer,
+		headerCh:       make(chan dataPack, 1),
+		bodyCh:         make(chan dataPack, 1),
+		receiptCh:      make(chan dataPack, 1),
+		bodyWakeCh:     make(chan bool, 1),
+		receiptWakeCh:  make(chan bool, 1),
+		headerProcCh:   make(chan []*types.Header, 1),
+		quitCh:         make(chan struct{}),
+		stateCh:        make(chan dataPack),
 		stateSyncStart: make(chan *stateSync),
 		trackStateReq:  make(chan *stateReq),
-		stateCh:        make(chan dataPack),
 	}
 	go dl.qosTuner()
 	go dl.stateFetcher()
@@ -223,11 +248,11 @@ func (d *Downloader) Progress() ethereum.SyncProgress {
 	current := uint64(0)
 	switch d.mode {
 	case FullSync:
-		current = d.headBlock().NumberU64()
+		current = d.blockchain.CurrentBlock().NumberU64()
 	case FastSync:
-		current = d.headFastBlock().NumberU64()
+		current = d.blockchain.CurrentFastBlock().NumberU64()
 	case LightSync:
-		current = d.headHeader().Number.Uint64()
+		current = d.lightchain.CurrentHeader().Number.Uint64()
 	}
 	return ethereum.SyncProgress{
 		StartingBlock: d.syncStatsChainOrigin,
@@ -245,13 +270,11 @@ func (d *Downloader) Synchronising() bool {
 
 // RegisterPeer injects a new download peer into the set of block source to be
 // used for fetching hashes and blocks from.
-func (d *Downloader) RegisterPeer(id string, version int, currentHead currentHeadRetrievalFn,
-	getRelHeaders relativeHeaderFetcherFn, getAbsHeaders absoluteHeaderFetcherFn, getBlockBodies blockBodyFetcherFn,
-	getReceipts receiptFetcherFn, getNodeData stateFetcherFn) error {
+func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error {
 
 	logger := log.New("peer", id)
 	logger.Trace("Registering sync peer")
-	if err := d.peers.Register(newPeer(id, version, currentHead, getRelHeaders, getAbsHeaders, getBlockBodies, getReceipts, getNodeData, logger)); err != nil {
+	if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil {
 		logger.Error("Failed to register sync peer", "err", err)
 		return err
 	}
@@ -260,6 +283,11 @@ func (d *Downloader) RegisterPeer(id string, version int, currentHead currentHea
 	return nil
 }
 
+// RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer.
+func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) error {
+	return d.RegisterPeer(id, version, &lightPeerWrapper{peer})
+}
+
 // UnregisterPeer remove a peer from the known list, preventing any action from
 // the specified peer. An effort is also made to return any pending fetches into
 // the queue.
@@ -371,7 +399,7 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
 
 // syncWithPeer starts a block synchronization based on the hash chain from the
 // specified peer and head hash.
-func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err error) {
+func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.Int) (err error) {
 	d.mux.Post(StartEvent{})
 	defer func() {
 		// reset on error
@@ -524,12 +552,12 @@ func (d *Downloader) Terminate() {
 
 // fetchHeight retrieves the head header of the remote peer to aid in estimating
 // the total time a pending synchronisation would take.
-func (d *Downloader) fetchHeight(p *peer) (*types.Header, error) {
+func (d *Downloader) fetchHeight(p *peerConnection) (*types.Header, error) {
 	p.log.Debug("Retrieving remote chain height")
 
 	// Request the advertised remote head block and wait for the response
-	head, _ := p.currentHead()
-	go p.getRelHeaders(head, 1, 0, false)
+	head, _ := p.peer.Head()
+	go p.peer.RequestHeadersByHash(head, 1, 0, false)
 
 	ttl := d.requestTTL()
 	timeout := time.After(ttl)
@@ -570,15 +598,15 @@ func (d *Downloader) fetchHeight(p *peer) (*types.Header, error) {
 // on the correct chain, checking the top N links should already get us a match.
 // In the rare scenario when we ended up on a long reorganisation (i.e. none of
 // the head links match), we do a binary search to find the common ancestor.
-func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) {
+func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, error) {
 	// Figure out the valid ancestor range to prevent rewrite attacks
-	floor, ceil := int64(-1), d.headHeader().Number.Uint64()
+	floor, ceil := int64(-1), d.lightchain.CurrentHeader().Number.Uint64()
 
 	p.log.Debug("Looking for common ancestor", "local", ceil, "remote", height)
 	if d.mode == FullSync {
-		ceil = d.headBlock().NumberU64()
+		ceil = d.blockchain.CurrentBlock().NumberU64()
 	} else if d.mode == FastSync {
-		ceil = d.headFastBlock().NumberU64()
+		ceil = d.blockchain.CurrentFastBlock().NumberU64()
 	}
 	if ceil >= MaxForkAncestry {
 		floor = int64(ceil - MaxForkAncestry)
@@ -598,7 +626,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) {
 	if count > limit {
 		count = limit
 	}
-	go p.getAbsHeaders(uint64(from), count, 15, false)
+	go p.peer.RequestHeadersByNumber(uint64(from), count, 15, false)
 
 	// Wait for the remote response to the head fetch
 	number, hash := uint64(0), common.Hash{}
@@ -638,7 +666,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) {
 					continue
 				}
 				// Otherwise check if we already know the header or not
-				if (d.mode == FullSync && d.hasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.hasHeader(headers[i].Hash())) {
+				if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash())) {
 					number, hash = headers[i].Number.Uint64(), headers[i].Hash()
 
 					// If every header is known, even future ones, the peer straight out lied about its head
@@ -680,7 +708,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) {
 		ttl := d.requestTTL()
 		timeout := time.After(ttl)
 
-		go p.getAbsHeaders(uint64(check), 1, 0, false)
+		go p.peer.RequestHeadersByNumber(uint64(check), 1, 0, false)
 
 		// Wait until a reply arrives to this request
 		for arrived := false; !arrived; {
@@ -703,11 +731,11 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) {
 				arrived = true
 
 				// Modify the search interval based on the response
-				if (d.mode == FullSync && !d.hasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) {
+				if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash())) {
 					end = check
 					break
 				}
-				header := d.getHeader(headers[0].Hash()) // Independent of sync mode, header surely exists
+				header := d.lightchain.GetHeaderByHash(headers[0].Hash()) // Independent of sync mode, header surely exists
 				if header.Number.Uint64() != check {
 					p.log.Debug("Received non requested header", "number", header.Number, "hash", header.Hash(), "request", check)
 					return 0, errBadPeer
@@ -741,7 +769,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) {
 // other peers are only accepted if they map cleanly to the skeleton. If no one
 // can fill in the skeleton - not even the origin peer - it's assumed invalid and
 // the origin is dropped.
-func (d *Downloader) fetchHeaders(p *peer, from uint64) error {
+func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
 	p.log.Debug("Directing header downloads", "origin", from)
 	defer p.log.Debug("Header download terminated")
 
@@ -761,10 +789,10 @@ func (d *Downloader) fetchHeaders(p *peer, from uint64) error {
 
 		if skeleton {
 			p.log.Trace("Fetching skeleton headers", "count", MaxHeaderFetch, "from", from)
-			go p.getAbsHeaders(from+uint64(MaxHeaderFetch)-1, MaxSkeletonSize, MaxHeaderFetch-1, false)
+			go p.peer.RequestHeadersByNumber(from+uint64(MaxHeaderFetch)-1, MaxSkeletonSize, MaxHeaderFetch-1, false)
 		} else {
 			p.log.Trace("Fetching full headers", "count", MaxHeaderFetch, "from", from)
-			go p.getAbsHeaders(from, MaxHeaderFetch, 0, false)
+			go p.peer.RequestHeadersByNumber(from, MaxHeaderFetch, 0, false)
 		}
 	}
 	// Start pulling the header chain skeleton until all is done
@@ -866,12 +894,12 @@ func (d *Downloader) fillHeaderSkeleton(from uint64, skeleton []*types.Header) (
 		}
 		expire   = func() map[string]int { return d.queue.ExpireHeaders(d.requestTTL()) }
 		throttle = func() bool { return false }
-		reserve  = func(p *peer, count int) (*fetchRequest, bool, error) {
+		reserve  = func(p *peerConnection, count int) (*fetchRequest, bool, error) {
 			return d.queue.ReserveHeaders(p, count), false, nil
 		}
-		fetch    = func(p *peer, req *fetchRequest) error { return p.FetchHeaders(req.From, MaxHeaderFetch) }
-		capacity = func(p *peer) int { return p.HeaderCapacity(d.requestRTT()) }
-		setIdle  = func(p *peer, accepted int) { p.SetHeadersIdle(accepted) }
+		fetch    = func(p *peerConnection, req *fetchRequest) error { return p.FetchHeaders(req.From, MaxHeaderFetch) }
+		capacity = func(p *peerConnection) int { return p.HeaderCapacity(d.requestRTT()) }
+		setIdle  = func(p *peerConnection, accepted int) { p.SetHeadersIdle(accepted) }
 	)
 	err := d.fetchParts(errCancelHeaderFetch, d.headerCh, deliver, d.queue.headerContCh, expire,
 		d.queue.PendingHeaders, d.queue.InFlightHeaders, throttle, reserve,
@@ -895,9 +923,9 @@ func (d *Downloader) fetchBodies(from uint64) error {
 			return d.queue.DeliverBodies(pack.peerId, pack.transactions, pack.uncles)
 		}
 		expire   = func() map[string]int { return d.queue.ExpireBodies(d.requestTTL()) }
-		fetch    = func(p *peer, req *fetchRequest) error { return p.FetchBodies(req) }
-		capacity = func(p *peer) int { return p.BlockCapacity(d.requestRTT()) }
-		setIdle  = func(p *peer, accepted int) { p.SetBodiesIdle(accepted) }
+		fetch    = func(p *peerConnection, req *fetchRequest) error { return p.FetchBodies(req) }
+		capacity = func(p *peerConnection) int { return p.BlockCapacity(d.requestRTT()) }
+		setIdle  = func(p *peerConnection, accepted int) { p.SetBodiesIdle(accepted) }
 	)
 	err := d.fetchParts(errCancelBodyFetch, d.bodyCh, deliver, d.bodyWakeCh, expire,
 		d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ShouldThrottleBlocks, d.queue.ReserveBodies,
@@ -919,9 +947,9 @@ func (d *Downloader) fetchReceipts(from uint64) error {
 			return d.queue.DeliverReceipts(pack.peerId, pack.receipts)
 		}
 		expire   = func() map[string]int { return d.queue.ExpireReceipts(d.requestTTL()) }
-		fetch    = func(p *peer, req *fetchRequest) error { return p.FetchReceipts(req) }
-		capacity = func(p *peer) int { return p.ReceiptCapacity(d.requestRTT()) }
-		setIdle  = func(p *peer, accepted int) { p.SetReceiptsIdle(accepted) }
+		fetch    = func(p *peerConnection, req *fetchRequest) error { return p.FetchReceipts(req) }
+		capacity = func(p *peerConnection) int { return p.ReceiptCapacity(d.requestRTT()) }
+		setIdle  = func(p *peerConnection, accepted int) { p.SetReceiptsIdle(accepted) }
 	)
 	err := d.fetchParts(errCancelReceiptFetch, d.receiptCh, deliver, d.receiptWakeCh, expire,
 		d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ShouldThrottleReceipts, d.queue.ReserveReceipts,
@@ -957,9 +985,9 @@ func (d *Downloader) fetchReceipts(from uint64) error {
 //  - setIdle:     network callback to set a peer back to idle and update its estimated capacity (traffic shaping)
 //  - kind:        textual label of the type being downloaded to display in log mesages
 func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool,
-	expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peer, int) (*fetchRequest, bool, error),
-	fetchHook func([]*types.Header), fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int,
-	idle func() ([]*peer, int), setIdle func(*peer, int), kind string) error {
+	expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, error),
+	fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int,
+	idle func() ([]*peerConnection, int), setIdle func(*peerConnection, int), kind string) error {
 
 	// Create a ticker to detect expired retrieval tasks
 	ticker := time.NewTicker(100 * time.Millisecond)
@@ -1124,23 +1152,19 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 			for i, header := range rollback {
 				hashes[i] = header.Hash()
 			}
-			lastHeader, lastFastBlock, lastBlock := d.headHeader().Number, common.Big0, common.Big0
-			if d.headFastBlock != nil {
-				lastFastBlock = d.headFastBlock().Number()
+			lastHeader, lastFastBlock, lastBlock := d.lightchain.CurrentHeader().Number, common.Big0, common.Big0
+			if d.mode != LightSync {
+				lastFastBlock = d.blockchain.CurrentFastBlock().Number()
+				lastBlock = d.blockchain.CurrentBlock().Number()
 			}
-			if d.headBlock != nil {
-				lastBlock = d.headBlock().Number()
-			}
-			d.rollback(hashes)
+			d.lightchain.Rollback(hashes)
 			curFastBlock, curBlock := common.Big0, common.Big0
-			if d.headFastBlock != nil {
-				curFastBlock = d.headFastBlock().Number()
-			}
-			if d.headBlock != nil {
-				curBlock = d.headBlock().Number()
+			if d.mode != LightSync {
+				curFastBlock = d.blockchain.CurrentFastBlock().Number()
+				curBlock = d.blockchain.CurrentBlock().Number()
 			}
 			log.Warn("Rolled back headers", "count", len(hashes),
-				"header", fmt.Sprintf("%d->%d", lastHeader, d.headHeader().Number),
+				"header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number),
 				"fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock),
 				"block", fmt.Sprintf("%d->%d", lastBlock, curBlock))
 
@@ -1190,7 +1214,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 				// L: Request new headers up from 11 (R's TD was higher, it must have something)
 				// R: Nothing to give
 				if d.mode != LightSync {
-					if !gotHeaders && td.Cmp(d.getTd(d.headBlock().Hash())) > 0 {
+					if !gotHeaders && td.Cmp(d.blockchain.GetTdByHash(d.blockchain.CurrentBlock().Hash())) > 0 {
 						return errStallingPeer
 					}
 				}
@@ -1202,7 +1226,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 				// queued for processing when the header download completes. However, as long as the
 				// peer gave us something useful, we're already happy/progressed (above check).
 				if d.mode == FastSync || d.mode == LightSync {
-					if td.Cmp(d.getTd(d.headHeader().Hash())) > 0 {
+					if td.Cmp(d.lightchain.GetTdByHash(d.lightchain.CurrentHeader().Hash())) > 0 {
 						return errStallingPeer
 					}
 				}
@@ -1232,7 +1256,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 					// Collect the yet unknown headers to mark them as uncertain
 					unknown := make([]*types.Header, 0, len(headers))
 					for _, header := range chunk {
-						if !d.hasHeader(header.Hash()) {
+						if !d.lightchain.HasHeader(header.Hash()) {
 							unknown = append(unknown, header)
 						}
 					}
@@ -1241,7 +1265,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 					if chunk[len(chunk)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot {
 						frequency = 1
 					}
-					if n, err := d.insertHeaders(chunk, frequency); err != nil {
+					if n, err := d.lightchain.InsertHeaderChain(chunk, frequency); err != nil {
 						// If some headers were inserted, add them too to the rollback list
 						if n > 0 {
 							rollback = append(rollback, chunk[:n]...)
@@ -1328,7 +1352,7 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error {
 		for i, result := range results[:items] {
 			blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
 		}
-		if index, err := d.insertBlocks(blocks); err != nil {
+		if index, err := d.blockchain.InsertChain(blocks); err != nil {
 			log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
 			return errInvalidChain
 		}
@@ -1368,6 +1392,7 @@ func (d *Downloader) processFastSyncContent(latest *types.Header) error {
 			stateSync.Cancel()
 			if err := d.commitPivotBlock(P); err != nil {
 				return err
+
 			}
 		}
 		if err := d.importBlockResults(afterP); err != nil {
@@ -1416,7 +1441,7 @@ func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *state
 			blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
 			receipts[i] = result.Receipts
 		}
-		if index, err := d.insertReceipts(blocks, receipts); err != nil {
+		if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil {
 			log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
 			return errInvalidChain
 		}
@@ -1434,10 +1459,10 @@ func (d *Downloader) commitPivotBlock(result *fetchResult) error {
 		return err
 	}
 	log.Debug("Committing fast sync pivot as new head", "number", b.Number(), "hash", b.Hash())
-	if _, err := d.insertReceipts([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil {
+	if _, err := d.blockchain.InsertReceiptChain([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil {
 		return err
 	}
-	return d.commitHeadBlock(b.Hash())
+	return d.blockchain.FastSyncCommitHead(b.Hash())
 }
 
 // DeliverHeaders injects a new batch of block headers received from a remote

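The two interfaces above replace the fifteen positional callbacks of the old constructor. As a compile-level sketch only (not part of the PR), the stub below shows the full method set a caller must now provide and how the new New signature is invoked; nullChain is a hypothetical placeholder for a real chain such as core's blockchain, and passing nil for the LightChain argument falls back to the BlockChain value, exactly as the constructor above does.

package main

import (
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/eth/downloader"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/event"
)

// nullChain satisfies downloader.BlockChain (and, via embedding, LightChain).
// The zero-value bodies exist only to illustrate the required method set.
type nullChain struct{}

func (nullChain) HasHeader(common.Hash) bool                { return false }
func (nullChain) GetHeaderByHash(common.Hash) *types.Header { return nil }
func (nullChain) CurrentHeader() *types.Header              { return &types.Header{Number: big.NewInt(0)} }
func (nullChain) GetTdByHash(common.Hash) *big.Int          { return big.NewInt(0) }
func (nullChain) InsertHeaderChain(h []*types.Header, _ int) (int, error) { return len(h), nil }
func (nullChain) Rollback([]common.Hash)                    {}
func (nullChain) HasBlockAndState(common.Hash) bool         { return false }
func (nullChain) GetBlockByHash(common.Hash) *types.Block   { return nil }
func (nullChain) CurrentBlock() *types.Block                { return nil }
func (nullChain) CurrentFastBlock() *types.Block            { return nil }
func (nullChain) FastSyncCommitHead(common.Hash) error      { return nil }
func (nullChain) InsertChain(b types.Blocks) (int, error)   { return len(b), nil }
func (nullChain) InsertReceiptChain(b types.Blocks, _ []types.Receipts) (int, error) {
	return len(b), nil
}

func main() {
	db, _ := ethdb.NewMemDatabase()

	// Before this PR every one of nullChain's methods was a separate callback
	// argument; now the chain is passed as a single interface value.
	dl := downloader.New(downloader.FullSync, db, new(event.TypeMux), nullChain{}, nil,
		func(id string) { /* drop a misbehaving peer */ })
	_ = dl
}

The test file below plays exactly this role: downloadTester now implements BlockChain itself and is handed to New directly.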
eth/downloader/downloader_test.go  +171 -163

@@ -96,9 +96,7 @@ func newTester() *downloadTester {
 	tester.stateDb, _ = ethdb.NewMemDatabase()
 	tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00})
 
-	tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
-		tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
-		tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
+	tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester, nil, tester.dropPeer)
 
 	return tester
 }
@@ -218,14 +216,14 @@ func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error {
 	return err
 }
 
-// hasHeader checks if a header is present in the testers canonical chain.
-func (dl *downloadTester) hasHeader(hash common.Hash) bool {
-	return dl.getHeader(hash) != nil
+// HasHeader checks if a header is present in the testers canonical chain.
+func (dl *downloadTester) HasHeader(hash common.Hash) bool {
+	return dl.GetHeaderByHash(hash) != nil
 }
 
-// hasBlock checks if a block and associated state is present in the testers canonical chain.
-func (dl *downloadTester) hasBlock(hash common.Hash) bool {
-	block := dl.getBlock(hash)
+// HasBlockAndState checks if a block and associated state is present in the testers canonical chain.
+func (dl *downloadTester) HasBlockAndState(hash common.Hash) bool {
+	block := dl.GetBlockByHash(hash)
 	if block == nil {
 		return false
 	}
@@ -233,24 +231,24 @@ func (dl *downloadTester) hasBlock(hash common.Hash) bool {
 	return err == nil
 }
 
-// getHeader retrieves a header from the testers canonical chain.
-func (dl *downloadTester) getHeader(hash common.Hash) *types.Header {
+// GetHeader retrieves a header from the testers canonical chain.
+func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
 	return dl.ownHeaders[hash]
 }
 
-// getBlock retrieves a block from the testers canonical chain.
-func (dl *downloadTester) getBlock(hash common.Hash) *types.Block {
+// GetBlock retrieves a block from the testers canonical chain.
+func (dl *downloadTester) GetBlockByHash(hash common.Hash) *types.Block {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
 	return dl.ownBlocks[hash]
 }
 
-// headHeader retrieves the current head header from the canonical chain.
-func (dl *downloadTester) headHeader() *types.Header {
+// CurrentHeader retrieves the current head header from the canonical chain.
+func (dl *downloadTester) CurrentHeader() *types.Header {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
@@ -262,8 +260,8 @@ func (dl *downloadTester) headHeader() *types.Header {
 	return dl.genesis.Header()
 }
 
-// headBlock retrieves the current head block from the canonical chain.
-func (dl *downloadTester) headBlock() *types.Block {
+// CurrentBlock retrieves the current head block from the canonical chain.
+func (dl *downloadTester) CurrentBlock() *types.Block {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
@@ -277,8 +275,8 @@ func (dl *downloadTester) headBlock() *types.Block {
 	return dl.genesis
 }
 
-// headFastBlock retrieves the current head fast-sync block from the canonical chain.
-func (dl *downloadTester) headFastBlock() *types.Block {
+// CurrentFastBlock retrieves the current head fast-sync block from the canonical chain.
+func (dl *downloadTester) CurrentFastBlock() *types.Block {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
@@ -290,26 +288,26 @@ func (dl *downloadTester) headFastBlock() *types.Block {
 	return dl.genesis
 }
 
-// commitHeadBlock manually sets the head block to a given hash.
-func (dl *downloadTester) commitHeadBlock(hash common.Hash) error {
+// FastSyncCommitHead manually sets the head block to a given hash.
+func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error {
 	// For now only check that the state trie is correct
-	if block := dl.getBlock(hash); block != nil {
+	if block := dl.GetBlockByHash(hash); block != nil {
 		_, err := trie.NewSecure(block.Root(), dl.stateDb, 0)
 		return err
 	}
 	return fmt.Errorf("non existent block: %x", hash[:4])
 }
 
-// getTd retrieves the block's total difficulty from the canonical chain.
-func (dl *downloadTester) getTd(hash common.Hash) *big.Int {
+// GetTdByHash retrieves the block's total difficulty from the canonical chain.
+func (dl *downloadTester) GetTdByHash(hash common.Hash) *big.Int {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
 	return dl.ownChainTd[hash]
 }
 
-// insertHeaders injects a new batch of headers into the simulated chain.
-func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) (int, error) {
+// InsertHeaderChain injects a new batch of headers into the simulated chain.
+func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq int) (int, error) {
 	dl.lock.Lock()
 	defer dl.lock.Unlock()
 
@@ -337,8 +335,8 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int)
 	return len(headers), nil
 }
 
-// insertBlocks injects a new batch of blocks into the simulated chain.
-func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
+// InsertChain injects a new batch of blocks into the simulated chain.
+func (dl *downloadTester) InsertChain(blocks types.Blocks) (int, error) {
 	dl.lock.Lock()
 	defer dl.lock.Unlock()
 
@@ -359,8 +357,8 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
 	return len(blocks), nil
 }
 
-// insertReceipts injects a new batch of receipts into the simulated chain.
-func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.Receipts) (int, error) {
+// InsertReceiptChain injects a new batch of receipts into the simulated chain.
+func (dl *downloadTester) InsertReceiptChain(blocks types.Blocks, receipts []types.Receipts) (int, error) {
 	dl.lock.Lock()
 	defer dl.lock.Unlock()
 
@@ -377,8 +375,8 @@ func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.R
 	return len(blocks), nil
 }
 
-// rollback removes some recently added elements from the chain.
-func (dl *downloadTester) rollback(hashes []common.Hash) {
+// Rollback removes some recently added elements from the chain.
+func (dl *downloadTester) Rollback(hashes []common.Hash) {
 	dl.lock.Lock()
 	defer dl.lock.Unlock()
 
@@ -406,14 +404,7 @@ func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Ha
 	defer dl.lock.Unlock()
 
 	var err error
-	switch version {
-	case 62:
-		err = dl.downloader.RegisterPeer(id, version, dl.peerCurrentHeadFn(id), dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), nil, nil)
-	case 63:
-		err = dl.downloader.RegisterPeer(id, version, dl.peerCurrentHeadFn(id), dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay))
-	case 64:
-		err = dl.downloader.RegisterPeer(id, version, dl.peerCurrentHeadFn(id), dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay))
-	}
+	err = dl.downloader.RegisterPeer(id, version, &downloadTesterPeer{dl, id, delay})
 	if err == nil {
 		// Assign the owned hashes, headers and blocks to the peer (deep copy)
 		dl.peerHashes[id] = make([]common.Hash, len(hashes))
@@ -471,139 +462,133 @@ func (dl *downloadTester) dropPeer(id string) {
 	dl.downloader.UnregisterPeer(id)
 }
 
-// peerCurrentHeadFn constructs a function to retrieve a peer's current head hash
+type downloadTesterPeer struct {
+	dl    *downloadTester
+	id    string
+	delay time.Duration
+}
+
+// Head constructs a function to retrieve a peer's current head hash
 // and total difficulty.
-func (dl *downloadTester) peerCurrentHeadFn(id string) func() (common.Hash, *big.Int) {
-	return func() (common.Hash, *big.Int) {
-		dl.lock.RLock()
-		defer dl.lock.RUnlock()
+func (dlp *downloadTesterPeer) Head() (common.Hash, *big.Int) {
+	dlp.dl.lock.RLock()
+	defer dlp.dl.lock.RUnlock()
 
-		return dl.peerHashes[id][0], nil
-	}
+	return dlp.dl.peerHashes[dlp.id][0], nil
 }
 
-// peerGetRelHeadersFn constructs a GetBlockHeaders function based on a hashed
+// RequestHeadersByHash constructs a GetBlockHeaders function based on a hashed
 // origin; associated with a particular peer in the download tester. The returned
 // function can be used to retrieve batches of headers from the particular peer.
-func (dl *downloadTester) peerGetRelHeadersFn(id string, delay time.Duration) func(common.Hash, int, int, bool) error {
-	return func(origin common.Hash, amount int, skip int, reverse bool) error {
-		// Find the canonical number of the hash
-		dl.lock.RLock()
-		number := uint64(0)
-		for num, hash := range dl.peerHashes[id] {
-			if hash == origin {
-				number = uint64(len(dl.peerHashes[id]) - num - 1)
-				break
-			}
+func (dlp *downloadTesterPeer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
+	// Find the canonical number of the hash
+	dlp.dl.lock.RLock()
+	number := uint64(0)
+	for num, hash := range dlp.dl.peerHashes[dlp.id] {
+		if hash == origin {
+			number = uint64(len(dlp.dl.peerHashes[dlp.id]) - num - 1)
+			break
 		}
-		dl.lock.RUnlock()
-
-		// Use the absolute header fetcher to satisfy the query
-		return dl.peerGetAbsHeadersFn(id, delay)(number, amount, skip, reverse)
 	}
+	dlp.dl.lock.RUnlock()
+
+	// Use the absolute header fetcher to satisfy the query
+	return dlp.RequestHeadersByNumber(number, amount, skip, reverse)
 }
 
-// peerGetAbsHeadersFn constructs a GetBlockHeaders function based on a numbered
+// RequestHeadersByNumber constructs a GetBlockHeaders function based on a numbered
 // origin; associated with a particular peer in the download tester. The returned
 // function can be used to retrieve batches of headers from the particular peer.
-func (dl *downloadTester) peerGetAbsHeadersFn(id string, delay time.Duration) func(uint64, int, int, bool) error {
-	return func(origin uint64, amount int, skip int, reverse bool) error {
-		time.Sleep(delay)
-
-		dl.lock.RLock()
-		defer dl.lock.RUnlock()
-
-		// Gather the next batch of headers
-		hashes := dl.peerHashes[id]
-		headers := dl.peerHeaders[id]
-		result := make([]*types.Header, 0, amount)
-		for i := 0; i < amount && len(hashes)-int(origin)-1-i*(skip+1) >= 0; i++ {
-			if header, ok := headers[hashes[len(hashes)-int(origin)-1-i*(skip+1)]]; ok {
-				result = append(result, header)
-			}
+func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
+	time.Sleep(dlp.delay)
+
+	dlp.dl.lock.RLock()
+	defer dlp.dl.lock.RUnlock()
+
+	// Gather the next batch of headers
+	hashes := dlp.dl.peerHashes[dlp.id]
+	headers := dlp.dl.peerHeaders[dlp.id]
+	result := make([]*types.Header, 0, amount)
+	for i := 0; i < amount && len(hashes)-int(origin)-1-i*(skip+1) >= 0; i++ {
+		if header, ok := headers[hashes[len(hashes)-int(origin)-1-i*(skip+1)]]; ok {
+			result = append(result, header)
 		}
-		// Delay delivery a bit to allow attacks to unfold
-		go func() {
-			time.Sleep(time.Millisecond)
-			dl.downloader.DeliverHeaders(id, result)
-		}()
-		return nil
 	}
+	// Delay delivery a bit to allow attacks to unfold
+	go func() {
+		time.Sleep(time.Millisecond)
+		dlp.dl.downloader.DeliverHeaders(dlp.id, result)
+	}()
+	return nil
 }
 
-// peerGetBodiesFn constructs a getBlockBodies method associated with a particular
+// RequestBodies constructs a getBlockBodies method associated with a particular
 // peer in the download tester. The returned function can be used to retrieve
 // batches of block bodies from the particularly requested peer.
-func (dl *downloadTester) peerGetBodiesFn(id string, delay time.Duration) func([]common.Hash) error {
-	return func(hashes []common.Hash) error {
-		time.Sleep(delay)
+func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash) error {
+	time.Sleep(dlp.delay)
 
-		dl.lock.RLock()
-		defer dl.lock.RUnlock()
+	dlp.dl.lock.RLock()
+	defer dlp.dl.lock.RUnlock()
 
-		blocks := dl.peerBlocks[id]
+	blocks := dlp.dl.peerBlocks[dlp.id]
 
-		transactions := make([][]*types.Transaction, 0, len(hashes))
-		uncles := make([][]*types.Header, 0, len(hashes))
+	transactions := make([][]*types.Transaction, 0, len(hashes))
+	uncles := make([][]*types.Header, 0, len(hashes))
 
-		for _, hash := range hashes {
-			if block, ok := blocks[hash]; ok {
-				transactions = append(transactions, block.Transactions())
-				uncles = append(uncles, block.Uncles())
-			}
+	for _, hash := range hashes {
+		if block, ok := blocks[hash]; ok {
+			transactions = append(transactions, block.Transactions())
+			uncles = append(uncles, block.Uncles())
 		}
-		go dl.downloader.DeliverBodies(id, transactions, uncles)
-
-		return nil
 	}
+	go dlp.dl.downloader.DeliverBodies(dlp.id, transactions, uncles)
+
+	return nil
 }
 
-// peerGetReceiptsFn constructs a getReceipts method associated with a particular
+// RequestReceipts constructs a getReceipts method associated with a particular
 // peer in the download tester. The returned function can be used to retrieve
 // batches of block receipts from the particularly requested peer.
-func (dl *downloadTester) peerGetReceiptsFn(id string, delay time.Duration) func([]common.Hash) error {
-	return func(hashes []common.Hash) error {
-		time.Sleep(delay)
+func (dlp *downloadTesterPeer) RequestReceipts(hashes []common.Hash) error {
+	time.Sleep(dlp.delay)
 
-		dl.lock.RLock()
-		defer dl.lock.RUnlock()
+	dlp.dl.lock.RLock()
+	defer dlp.dl.lock.RUnlock()
 
-		receipts := dl.peerReceipts[id]
+	receipts := dlp.dl.peerReceipts[dlp.id]
 
-		results := make([][]*types.Receipt, 0, len(hashes))
-		for _, hash := range hashes {
-			if receipt, ok := receipts[hash]; ok {
-				results = append(results, receipt)
-			}
+	results := make([][]*types.Receipt, 0, len(hashes))
+	for _, hash := range hashes {
+		if receipt, ok := receipts[hash]; ok {
+			results = append(results, receipt)
 		}
-		go dl.downloader.DeliverReceipts(id, results)
-
-		return nil
 	}
+	go dlp.dl.downloader.DeliverReceipts(dlp.id, results)
+
+	return nil
 }
 
-// peerGetNodeDataFn constructs a getNodeData method associated with a particular
+// RequestNodeData constructs a getNodeData method associated with a particular
 // peer in the download tester. The returned function can be used to retrieve
 // batches of node state data from the particularly requested peer.
-func (dl *downloadTester) peerGetNodeDataFn(id string, delay time.Duration) func([]common.Hash) error {
-	return func(hashes []common.Hash) error {
-		time.Sleep(delay)
-
-		dl.lock.RLock()
-		defer dl.lock.RUnlock()
-
-		results := make([][]byte, 0, len(hashes))
-		for _, hash := range hashes {
-			if data, err := dl.peerDb.Get(hash.Bytes()); err == nil {
-				if !dl.peerMissingStates[id][hash] {
-					results = append(results, data)
-				}
+func (dlp *downloadTesterPeer) RequestNodeData(hashes []common.Hash) error {
+	time.Sleep(dlp.delay)
+
+	dlp.dl.lock.RLock()
+	defer dlp.dl.lock.RUnlock()
+
+	results := make([][]byte, 0, len(hashes))
+	for _, hash := range hashes {
+		if data, err := dlp.dl.peerDb.Get(hash.Bytes()); err == nil {
+			if !dlp.dl.peerMissingStates[dlp.id][hash] {
+				results = append(results, data)
 			}
 		}
-		go dl.downloader.DeliverNodeData(id, results)
-
-		return nil
 	}
+	go dlp.dl.downloader.DeliverNodeData(dlp.id, results)
+
+	return nil
 }
 
 // assertOwnChain checks if the local chain contains the correct number of items
@@ -1212,7 +1197,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
 	if err := tester.sync("fast-attack", nil, mode); err == nil {
 		t.Fatalf("succeeded fast attacker synchronisation")
 	}
-	if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch {
+	if head := tester.CurrentHeader().Number.Int64(); int(head) > MaxHeaderFetch {
 		t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
 	}
 	// Attempt to sync with an attacker that feeds junk during the block import phase.
@@ -1226,11 +1211,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
 	if err := tester.sync("block-attack", nil, mode); err == nil {
 		t.Fatalf("succeeded block attacker synchronisation")
 	}
-	if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
+	if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
 		t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
 	}
 	if mode == FastSync {
-		if head := tester.headBlock().NumberU64(); head != 0 {
+		if head := tester.CurrentBlock().NumberU64(); head != 0 {
 			t.Errorf("fast sync pivot block #%d not rolled back", head)
 		}
 	}
@@ -1251,11 +1236,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
 	if err := tester.sync("withhold-attack", nil, mode); err == nil {
 		t.Fatalf("succeeded withholding attacker synchronisation")
 	}
-	if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
+	if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
 		t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
 	}
 	if mode == FastSync {
-		if head := tester.headBlock().NumberU64(); head != 0 {
+		if head := tester.CurrentBlock().NumberU64(); head != 0 {
 			t.Errorf("fast sync pivot block #%d not rolled back", head)
 		}
 	}
@@ -1670,6 +1655,48 @@ func TestDeliverHeadersHang64Full(t *testing.T)  { testDeliverHeadersHang(t, 64,
 func TestDeliverHeadersHang64Fast(t *testing.T)  { testDeliverHeadersHang(t, 64, FastSync) }
 func TestDeliverHeadersHang64Light(t *testing.T) { testDeliverHeadersHang(t, 64, LightSync) }
 
+type floodingTestPeer struct {
+	peer   Peer
+	tester *downloadTester
+}
+
+func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() }
+func (ftp *floodingTestPeer) RequestHeadersByHash(hash common.Hash, count int, skip int, reverse bool) error {
+	return ftp.peer.RequestHeadersByHash(hash, count, skip, reverse)
+}
+func (ftp *floodingTestPeer) RequestBodies(hashes []common.Hash) error {
+	return ftp.peer.RequestBodies(hashes)
+}
+func (ftp *floodingTestPeer) RequestReceipts(hashes []common.Hash) error {
+	return ftp.peer.RequestReceipts(hashes)
+}
+func (ftp *floodingTestPeer) RequestNodeData(hashes []common.Hash) error {
+	return ftp.peer.RequestNodeData(hashes)
+}
+
+func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int, reverse bool) error {
+	deliveriesDone := make(chan struct{}, 500)
+	for i := 0; i < cap(deliveriesDone); i++ {
+		peer := fmt.Sprintf("fake-peer%d", i)
+		go func() {
+			ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}})
+			deliveriesDone <- struct{}{}
+		}()
+	}
+	// Deliver the actual requested headers.
+	go ftp.peer.RequestHeadersByNumber(from, count, skip, reverse)
+	// None of the extra deliveries should block.
+	timeout := time.After(15 * time.Second)
+	for i := 0; i < cap(deliveriesDone); i++ {
+		select {
+		case <-deliveriesDone:
+		case <-timeout:
+			panic("blocked")
+		}
+	}
+	return nil
+}
+
 func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
 	t.Parallel()
 
@@ -1677,7 +1704,6 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
 	defer master.terminate()
 
 	hashes, headers, blocks, receipts := master.makeChain(5, 0, master.genesis, nil, false)
-	fakeHeads := []*types.Header{{}, {}, {}, {}}
 	for i := 0; i < 200; i++ {
 		tester := newTester()
 		tester.peerDb = master.peerDb
@@ -1685,29 +1711,11 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
 		tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
 		// Whenever the downloader requests headers, flood it with
 		// a lot of unrequested header deliveries.
-		tester.downloader.peers.peers["peer"].getAbsHeaders = func(from uint64, count, skip int, reverse bool) error {
-			deliveriesDone := make(chan struct{}, 500)
-			for i := 0; i < cap(deliveriesDone); i++ {
-				peer := fmt.Sprintf("fake-peer%d", i)
-				go func() {
-					tester.downloader.DeliverHeaders(peer, fakeHeads)
-					deliveriesDone <- struct{}{}
-				}()
-			}
-			// Deliver the actual requested headers.
-			impl := tester.peerGetAbsHeadersFn("peer", 0)
-			go impl(from, count, skip, reverse)
-			// None of the extra deliveries should block.
-			timeout := time.After(15 * time.Second)
-			for i := 0; i < cap(deliveriesDone); i++ {
-				select {
-				case <-deliveriesDone:
-				case <-timeout:
-					panic("blocked")
-				}
-			}
-			return nil
+		tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{
+			tester.downloader.peers.peers["peer"].peer,
+			tester,
 		}
+
 		if err := tester.sync("peer", nil, mode); err != nil {
 			t.Errorf("sync failed: %v", err)
 		}
@@ -1739,7 +1747,7 @@ func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) {
 	for i := 0; i < fsPivotInterval; i++ {
 		tester.peerMissingStates["peer"][headers[hashes[fsMinFullBlocks+i]].Root] = true
 	}
-	tester.downloader.peers.peers["peer"].getNodeData = tester.peerGetNodeDataFn("peer", 500*time.Millisecond) // Enough to reach the critical section
+	(tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).delay = 500 * time.Millisecond // Enough to reach the critical section
 
 	// Synchronise with the peer a few times and make sure they fail until the retry limit
 	for i := 0; i < int(fsCriticalTrials)-1; i++ {
@@ -1758,7 +1766,7 @@ func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) {
 			tester.lock.Lock()
 			tester.peerHeaders["peer"][hashes[fsMinFullBlocks-1]] = headers[hashes[fsMinFullBlocks-1]]
 			tester.peerMissingStates["peer"] = map[common.Hash]bool{tester.downloader.fsPivotLock.Root: true}
-			tester.downloader.peers.peers["peer"].getNodeData = tester.peerGetNodeDataFn("peer", 0)
+			(tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).delay = 0
 			tester.lock.Unlock()
 		}
 	}

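Peer registration follows the same pattern: instead of handing the downloader six fetcher functions, a caller implements the Peer (or LightPeer) interface defined in peer.go below. The sketch here is hypothetical and compile-level only; headerOnlyPeer stands in for something like a les server connection, and because RegisterLightPeer wraps it in a lightPeerWrapper whose body/receipt/state methods panic, such a peer is only meaningful in LightSync mode.

package example

import (
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/eth/downloader"
)

// headerOnlyPeer is a hypothetical header-only peer implementing the three
// LightPeer methods. A real implementation would send GetBlockHeaders requests
// over the network and feed the replies back via Downloader.DeliverHeaders.
type headerOnlyPeer struct {
	head common.Hash
	td   *big.Int
}

func (p *headerOnlyPeer) Head() (common.Hash, *big.Int) { return p.head, p.td }

func (p *headerOnlyPeer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
	return nil // placeholder: issue the network request here
}

func (p *headerOnlyPeer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
	return nil // placeholder: issue the network request here
}

// registerLight wires the peer into an existing downloader instance.
func registerLight(d *downloader.Downloader, id string, version int, p *headerOnlyPeer) error {
	return d.RegisterLightPeer(id, version, p)
}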
eth/downloader/peer.go  +87 -75

@@ -39,24 +39,14 @@ const (
 	measurementImpact = 0.1  // The impact a single measurement has on a peer's final throughput value.
 )
 
-// Head hash and total difficulty retriever for
-type currentHeadRetrievalFn func() (common.Hash, *big.Int)
-
-// Block header and body fetchers belonging to eth/62 and above
-type relativeHeaderFetcherFn func(common.Hash, int, int, bool) error
-type absoluteHeaderFetcherFn func(uint64, int, int, bool) error
-type blockBodyFetcherFn func([]common.Hash) error
-type receiptFetcherFn func([]common.Hash) error
-type stateFetcherFn func([]common.Hash) error
-
 var (
 	errAlreadyFetching   = errors.New("already fetching blocks from peer")
 	errAlreadyRegistered = errors.New("peer is already registered")
 	errNotRegistered     = errors.New("peer is not registered")
 )
 
-// peer represents an active peer from which hashes and blocks are retrieved.
-type peer struct {
+// peerConnection represents an active peer from which hashes and blocks are retrieved.
+type peerConnection struct {
 	id string // Unique identifier of the peer
 
 	headerIdle  int32 // Current header activity state of the peer (idle = 0, active = 1)
@@ -78,37 +68,57 @@ type peer struct {
 
 	lacking map[common.Hash]struct{} // Set of hashes not to request (didn't have previously)
 
-	currentHead currentHeadRetrievalFn // Method to fetch the currently known head of the peer
-
-	getRelHeaders  relativeHeaderFetcherFn // [eth/62] Method to retrieve a batch of headers from an origin hash
-	getAbsHeaders  absoluteHeaderFetcherFn // [eth/62] Method to retrieve a batch of headers from an absolute position
-	getBlockBodies blockBodyFetcherFn      // [eth/62] Method to retrieve a batch of block bodies
-
-	getReceipts receiptFetcherFn // [eth/63] Method to retrieve a batch of block transaction receipts
-	getNodeData stateFetcherFn   // [eth/63] Method to retrieve a batch of state trie data
+	peer Peer
 
 	version int        // Eth protocol version number to switch strategies
 	log     log.Logger // Contextual logger to add extra infos to peer logs
 	lock    sync.RWMutex
 }
 
-// newPeer create a new downloader peer, with specific hash and block retrieval
-// mechanisms.
-func newPeer(id string, version int, currentHead currentHeadRetrievalFn,
-	getRelHeaders relativeHeaderFetcherFn, getAbsHeaders absoluteHeaderFetcherFn, getBlockBodies blockBodyFetcherFn,
-	getReceipts receiptFetcherFn, getNodeData stateFetcherFn, logger log.Logger) *peer {
+// LightPeer encapsulates the methods required to synchronise with a remote light peer.
+type LightPeer interface {
+	Head() (common.Hash, *big.Int)
+	RequestHeadersByHash(common.Hash, int, int, bool) error
+	RequestHeadersByNumber(uint64, int, int, bool) error
+}
+
+// Peer encapsulates the methods required to synchronise with a remote full peer.
+type Peer interface {
+	LightPeer
+	RequestBodies([]common.Hash) error
+	RequestReceipts([]common.Hash) error
+	RequestNodeData([]common.Hash) error
+}
+
+// lightPeerWrapper wraps a LightPeer struct, stubbing out the Peer-only methods.
+type lightPeerWrapper struct {
+	peer LightPeer
+}
+
+func (w *lightPeerWrapper) Head() (common.Hash, *big.Int) { return w.peer.Head() }
+func (w *lightPeerWrapper) RequestHeadersByHash(h common.Hash, amount int, skip int, reverse bool) error {
+	return w.peer.RequestHeadersByHash(h, amount, skip, reverse)
+}
+func (w *lightPeerWrapper) RequestHeadersByNumber(i uint64, amount int, skip int, reverse bool) error {
+	return w.peer.RequestHeadersByNumber(i, amount, skip, reverse)
+}
+func (w *lightPeerWrapper) RequestBodies([]common.Hash) error {
+	panic("RequestBodies not supported in light client mode sync")
+}
+func (w *lightPeerWrapper) RequestReceipts([]common.Hash) error {
+	panic("RequestReceipts not supported in light client mode sync")
+}
+func (w *lightPeerWrapper) RequestNodeData([]common.Hash) error {
+	panic("RequestNodeData not supported in light client mode sync")
+}
 
-	return &peer{
+// newPeerConnection creates a new downloader peer.
+func newPeerConnection(id string, version int, peer Peer, logger log.Logger) *peerConnection {
+	return &peerConnection{
 		id:      id,
 		lacking: make(map[common.Hash]struct{}),
 
-		currentHead:    currentHead,
-		getRelHeaders:  getRelHeaders,
-		getAbsHeaders:  getAbsHeaders,
-		getBlockBodies: getBlockBodies,
-
-		getReceipts: getReceipts,
-		getNodeData: getNodeData,
+		peer: peer,
 
 		version: version,
 		log:     logger,
@@ -116,7 +126,7 @@ func newPeer(id string, version int, currentHead currentHeadRetrievalFn,
 }
 
 // Reset clears the internal state of a peer entity.
-func (p *peer) Reset() {
+func (p *peerConnection) Reset() {
 	p.lock.Lock()
 	defer p.lock.Unlock()
 
@@ -134,7 +144,7 @@ func (p *peer) Reset() {
 }
 
 // FetchHeaders sends a header retrieval request to the remote peer.
-func (p *peer) FetchHeaders(from uint64, count int) error {
+func (p *peerConnection) FetchHeaders(from uint64, count int) error {
 	// Sanity check the protocol version
 	if p.version < 62 {
 		panic(fmt.Sprintf("header fetch [eth/62+] requested on eth/%d", p.version))
@@ -146,13 +156,13 @@ func (p *peer) FetchHeaders(from uint64, count int) error {
 	p.headerStarted = time.Now()
 
 	// Issue the header retrieval request (absolut upwards without gaps)
-	go p.getAbsHeaders(from, count, 0, false)
+	go p.peer.RequestHeadersByNumber(from, count, 0, false)
 
 	return nil
 }
 
 // FetchBodies sends a block body retrieval request to the remote peer.
-func (p *peer) FetchBodies(request *fetchRequest) error {
+func (p *peerConnection) FetchBodies(request *fetchRequest) error {
 	// Sanity check the protocol version
 	if p.version < 62 {
 		panic(fmt.Sprintf("body fetch [eth/62+] requested on eth/%d", p.version))
@@ -168,13 +178,13 @@ func (p *peer) FetchBodies(request *fetchRequest) error {
 	for _, header := range request.Headers {
 		hashes = append(hashes, header.Hash())
 	}
-	go p.getBlockBodies(hashes)
+	go p.peer.RequestBodies(hashes)
 
 	return nil
 }
 
 // FetchReceipts sends a receipt retrieval request to the remote peer.
-func (p *peer) FetchReceipts(request *fetchRequest) error {
+func (p *peerConnection) FetchReceipts(request *fetchRequest) error {
 	// Sanity check the protocol version
 	if p.version < 63 {
 		panic(fmt.Sprintf("body fetch [eth/63+] requested on eth/%d", p.version))
@@ -190,13 +200,13 @@ func (p *peer) FetchReceipts(request *fetchRequest) error {
 	for _, header := range request.Headers {
 		hashes = append(hashes, header.Hash())
 	}
-	go p.getReceipts(hashes)
+	go p.peer.RequestReceipts(hashes)
 
 	return nil
 }
 
 // FetchNodeData sends a node state data retrieval request to the remote peer.
-func (p *peer) FetchNodeData(hashes []common.Hash) error {
+func (p *peerConnection) FetchNodeData(hashes []common.Hash) error {
 	// Sanity check the protocol version
 	if p.version < 63 {
 		panic(fmt.Sprintf("node data fetch [eth/63+] requested on eth/%d", p.version))
@@ -206,48 +216,50 @@ func (p *peer) FetchNodeData(hashes []common.Hash) error {
 		return errAlreadyFetching
 	}
 	p.stateStarted = time.Now()
-	go p.getNodeData(hashes)
+
+	go p.peer.RequestNodeData(hashes)
+
 	return nil
 }
 
 // SetHeadersIdle sets the peer to idle, allowing it to execute new header retrieval
 // requests. Its estimated header retrieval throughput is updated with that measured
 // just now.
-func (p *peer) SetHeadersIdle(delivered int) {
+func (p *peerConnection) SetHeadersIdle(delivered int) {
 	p.setIdle(p.headerStarted, delivered, &p.headerThroughput, &p.headerIdle)
 }
 
 // SetBlocksIdle sets the peer to idle, allowing it to execute new block retrieval
 // requests. Its estimated block retrieval throughput is updated with that measured
 // just now.
-func (p *peer) SetBlocksIdle(delivered int) {
+func (p *peerConnection) SetBlocksIdle(delivered int) {
 	p.setIdle(p.blockStarted, delivered, &p.blockThroughput, &p.blockIdle)
 }
 
 // SetBodiesIdle sets the peer to idle, allowing it to execute block body retrieval
 // requests. Its estimated body retrieval throughput is updated with that measured
 // just now.
-func (p *peer) SetBodiesIdle(delivered int) {
+func (p *peerConnection) SetBodiesIdle(delivered int) {
 	p.setIdle(p.blockStarted, delivered, &p.blockThroughput, &p.blockIdle)
 }
 
 // SetReceiptsIdle sets the peer to idle, allowing it to execute new receipt
 // retrieval requests. Its estimated receipt retrieval throughput is updated
 // with that measured just now.
-func (p *peer) SetReceiptsIdle(delivered int) {
+func (p *peerConnection) SetReceiptsIdle(delivered int) {
 	p.setIdle(p.receiptStarted, delivered, &p.receiptThroughput, &p.receiptIdle)
 }
 
 // SetNodeDataIdle sets the peer to idle, allowing it to execute new state trie
 // data retrieval requests. Its estimated state retrieval throughput is updated
 // with that measured just now.
-func (p *peer) SetNodeDataIdle(delivered int) {
+func (p *peerConnection) SetNodeDataIdle(delivered int) {
 	p.setIdle(p.stateStarted, delivered, &p.stateThroughput, &p.stateIdle)
 }
 
 // setIdle sets the peer to idle, allowing it to execute new retrieval requests.
 // Its estimated retrieval throughput is updated with that measured just now.
-func (p *peer) setIdle(started time.Time, delivered int, throughput *float64, idle *int32) {
+func (p *peerConnection) setIdle(started time.Time, delivered int, throughput *float64, idle *int32) {
 	// Irrelevant of the scaling, make sure the peer ends up idle
 	defer atomic.StoreInt32(idle, 0)
 
@@ -274,7 +286,7 @@ func (p *peer) setIdle(started time.Time, delivered int, throughput *float64, id
 
 // HeaderCapacity retrieves the peers header download allowance based on its
 // previously discovered throughput.
-func (p *peer) HeaderCapacity(targetRTT time.Duration) int {
+func (p *peerConnection) HeaderCapacity(targetRTT time.Duration) int {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
@@ -283,7 +295,7 @@ func (p *peer) HeaderCapacity(targetRTT time.Duration) int {
 
 // BlockCapacity retrieves the peers block download allowance based on its
 // previously discovered throughput.
-func (p *peer) BlockCapacity(targetRTT time.Duration) int {
+func (p *peerConnection) BlockCapacity(targetRTT time.Duration) int {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
@@ -292,7 +304,7 @@ func (p *peer) BlockCapacity(targetRTT time.Duration) int {
 
 // ReceiptCapacity retrieves the peers receipt download allowance based on its
 // previously discovered throughput.
-func (p *peer) ReceiptCapacity(targetRTT time.Duration) int {
+func (p *peerConnection) ReceiptCapacity(targetRTT time.Duration) int {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
@@ -301,7 +313,7 @@ func (p *peer) ReceiptCapacity(targetRTT time.Duration) int {
 
 // NodeDataCapacity retrieves the peer's state download allowance based on its
 // previously discovered throughput.
-func (p *peer) NodeDataCapacity(targetRTT time.Duration) int {
+func (p *peerConnection) NodeDataCapacity(targetRTT time.Duration) int {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
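The four *Capacity methods above all apply the same rule: multiply the peer's estimated throughput by the target round-trip time and clamp the result to the protocol's per-request batch limit. A hedged, self-contained sketch of that calculation (the floor of one item and the exact rounding are assumptions):

package sketch

import (
	"math"
	"time"
)

// capacity estimates how many items a peer can deliver within targetRTT
// given its measured throughput in items/second, clamped to batchCap.
func capacity(throughput float64, targetRTT time.Duration, batchCap int) int {
	items := throughput * float64(targetRTT) / float64(time.Second)
	return int(math.Min(math.Max(1, items), float64(batchCap)))
}

For headers, for example, the call would be roughly capacity(p.headerThroughput, targetRTT, MaxHeaderFetch).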
@@ -311,7 +323,7 @@ func (p *peer) NodeDataCapacity(targetRTT time.Duration) int {
 // MarkLacking appends a new entity to the set of items (blocks, receipts, states)
 // that a peer is known not to have (i.e. they have been requested before). If the
 // set reaches its maximum allowed capacity, items are randomly dropped off.
-func (p *peer) MarkLacking(hash common.Hash) {
+func (p *peerConnection) MarkLacking(hash common.Hash) {
 	p.lock.Lock()
 	defer p.lock.Unlock()
 
@@ -326,7 +338,7 @@ func (p *peer) MarkLacking(hash common.Hash) {
 
 // Lacks retrieves whether the hash of a blockchain item is on the peer's lacking
 // list (i.e. whether we know that the peer does not have it).
-func (p *peer) Lacks(hash common.Hash) bool {
+func (p *peerConnection) Lacks(hash common.Hash) bool {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
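MarkLacking and Lacks maintain a bounded per-peer set of hashes the peer has failed to serve, so the queue can route those retrievals elsewhere. A self-contained sketch of such a capacity-bounded set, relying on Go's randomised map iteration order for the "randomly dropped off" behaviour (the cap value is illustrative):

package sketch

import "github.com/ethereum/go-ethereum/common"

// maxLackingHashes is an illustrative cap on the per-peer lacking set.
const maxLackingHashes = 4096

type lackingSet map[common.Hash]struct{}

// mark records that the peer is missing hash, evicting an arbitrary
// entry once the set is full (mirrors MarkLacking's random drop-off).
func (l lackingSet) mark(hash common.Hash) {
	for len(l) >= maxLackingHashes {
		for drop := range l {
			delete(l, drop)
			break
		}
	}
	l[hash] = struct{}{}
}

// lacks reports whether hash is known to be missing from the peer.
func (l lackingSet) lacks(hash common.Hash) bool {
	_, ok := l[hash]
	return ok
}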
@@ -337,7 +349,7 @@ func (p *peer) Lacks(hash common.Hash) bool {
 // peerSet represents the collection of active peers participating in the chain
 // download procedure.
 type peerSet struct {
-	peers       map[string]*peer
+	peers       map[string]*peerConnection
 	newPeerFeed event.Feed
 	lock        sync.RWMutex
 }
@@ -345,11 +357,11 @@ type peerSet struct {
 // newPeerSet creates a new peer set to track the active download sources.
 func newPeerSet() *peerSet {
 	return &peerSet{
-		peers: make(map[string]*peer),
+		peers: make(map[string]*peerConnection),
 	}
 }
 
-func (ps *peerSet) SubscribeNewPeers(ch chan<- *peer) event.Subscription {
+func (ps *peerSet) SubscribeNewPeers(ch chan<- *peerConnection) event.Subscription {
 	return ps.newPeerFeed.Subscribe(ch)
 }
 
@@ -370,7 +382,7 @@ func (ps *peerSet) Reset() {
 // The method also sets the starting throughput values of the new peer to the
 // average of all existing peers, to give it a realistic chance of being used
 // for data retrievals.
-func (ps *peerSet) Register(p *peer) error {
+func (ps *peerSet) Register(p *peerConnection) error {
 	// Retrieve the current median RTT as a sane default
 	p.rtt = ps.medianRTT()
 
@@ -417,7 +429,7 @@ func (ps *peerSet) Unregister(id string) error {
 }
 
 // Peer retrieves the registered peer with the given id.
-func (ps *peerSet) Peer(id string) *peer {
+func (ps *peerSet) Peer(id string) *peerConnection {
 	ps.lock.RLock()
 	defer ps.lock.RUnlock()
 
@@ -433,11 +445,11 @@ func (ps *peerSet) Len() int {
 }
 
 // AllPeers retrieves a flat list of all the peers within the set.
-func (ps *peerSet) AllPeers() []*peer {
+func (ps *peerSet) AllPeers() []*peerConnection {
 	ps.lock.RLock()
 	defer ps.lock.RUnlock()
 
-	list := make([]*peer, 0, len(ps.peers))
+	list := make([]*peerConnection, 0, len(ps.peers))
 	for _, p := range ps.peers {
 		list = append(list, p)
 	}
@@ -446,11 +458,11 @@ func (ps *peerSet) AllPeers() []*peer {
 
 // HeaderIdlePeers retrieves a flat list of all the currently header-idle peers
 // within the active peer set, ordered by their reputation.
-func (ps *peerSet) HeaderIdlePeers() ([]*peer, int) {
-	idle := func(p *peer) bool {
+func (ps *peerSet) HeaderIdlePeers() ([]*peerConnection, int) {
+	idle := func(p *peerConnection) bool {
 		return atomic.LoadInt32(&p.headerIdle) == 0
 	}
-	throughput := func(p *peer) float64 {
+	throughput := func(p *peerConnection) float64 {
 		p.lock.RLock()
 		defer p.lock.RUnlock()
 		return p.headerThroughput
@@ -460,11 +472,11 @@ func (ps *peerSet) HeaderIdlePeers() ([]*peer, int) {
 
 // BodyIdlePeers retrieves a flat list of all the currently body-idle peers within
 // the active peer set, ordered by their reputation.
-func (ps *peerSet) BodyIdlePeers() ([]*peer, int) {
-	idle := func(p *peer) bool {
+func (ps *peerSet) BodyIdlePeers() ([]*peerConnection, int) {
+	idle := func(p *peerConnection) bool {
 		return atomic.LoadInt32(&p.blockIdle) == 0
 	}
-	throughput := func(p *peer) float64 {
+	throughput := func(p *peerConnection) float64 {
 		p.lock.RLock()
 		defer p.lock.RUnlock()
 		return p.blockThroughput
@@ -474,11 +486,11 @@ func (ps *peerSet) BodyIdlePeers() ([]*peer, int) {
 
 // ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers
 // within the active peer set, ordered by their reputation.
-func (ps *peerSet) ReceiptIdlePeers() ([]*peer, int) {
-	idle := func(p *peer) bool {
+func (ps *peerSet) ReceiptIdlePeers() ([]*peerConnection, int) {
+	idle := func(p *peerConnection) bool {
 		return atomic.LoadInt32(&p.receiptIdle) == 0
 	}
-	throughput := func(p *peer) float64 {
+	throughput := func(p *peerConnection) float64 {
 		p.lock.RLock()
 		defer p.lock.RUnlock()
 		return p.receiptThroughput
@@ -488,11 +500,11 @@ func (ps *peerSet) ReceiptIdlePeers() ([]*peer, int) {
 
 // NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle
 // peers within the active peer set, ordered by their reputation.
-func (ps *peerSet) NodeDataIdlePeers() ([]*peer, int) {
-	idle := func(p *peer) bool {
+func (ps *peerSet) NodeDataIdlePeers() ([]*peerConnection, int) {
+	idle := func(p *peerConnection) bool {
 		return atomic.LoadInt32(&p.stateIdle) == 0
 	}
-	throughput := func(p *peer) float64 {
+	throughput := func(p *peerConnection) float64 {
 		p.lock.RLock()
 		defer p.lock.RUnlock()
 		return p.stateThroughput
@@ -503,11 +515,11 @@ func (ps *peerSet) NodeDataIdlePeers() ([]*peer, int) {
 // idlePeers retrieves a flat list of all currently idle peers satisfying the
 // protocol version constraints, using the provided function to check idleness.
 // The resulting set of peers is sorted by their measured throughput.
-func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peer) bool, throughput func(*peer) float64) ([]*peer, int) {
+func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) {
 	ps.lock.RLock()
 	defer ps.lock.RUnlock()
 
-	idle, total := make([]*peer, 0, len(ps.peers)), 0
+	idle, total := make([]*peerConnection, 0, len(ps.peers)), 0
 	for _, p := range ps.peers {
 		if p.version >= minProtocol && p.version <= maxProtocol {
 			if idleCheck(p) {

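Taken together, the renamed peerSet APIs leave the dispatch pattern untouched: a fetcher still asks for the idle peers sorted by throughput and sizes each request to the peer's capacity before reserving work from the queue (see the queue.go changes below). A sketch of one such round, written as if inside package downloader and using only the methods visible in this diff; error handling and the actual network send are elided:

// assignHeaderWork hands header batches to idle peers, best throughput first.
func assignHeaderWork(d *Downloader, targetRTT time.Duration) {
	idle, _ := d.peers.HeaderIdlePeers() // header-idle peers, sorted by measured rate
	for _, p := range idle {
		count := p.HeaderCapacity(targetRTT) // size the request to the peer's throughput
		if req := d.queue.ReserveHeaders(p, count); req != nil {
			_ = req // dispatch req to the concrete header fetcher (elided)
		}
	}
}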
+ 5 - 5
eth/downloader/queue.go

@@ -41,7 +41,7 @@ var (
 
 // fetchRequest is a currently running data retrieval operation.
 type fetchRequest struct {
-	Peer    *peer               // Peer to which the request was sent
+	Peer    *peerConnection     // Peer to which the request was sent
 	From    uint64              // [eth/62] Requested chain element index (used for skeleton fills only)
 	Hashes  map[common.Hash]int // [eth/61] Requested hashes with their insertion index (priority)
 	Headers []*types.Header     // [eth/62] Requested headers, sorted by request order
@@ -391,7 +391,7 @@ func (q *queue) countProcessableItems() int {
 
 // ReserveHeaders reserves a set of headers for the given peer, skipping any
 // previously failed batches.
-func (q *queue) ReserveHeaders(p *peer, count int) *fetchRequest {
+func (q *queue) ReserveHeaders(p *peerConnection, count int) *fetchRequest {
 	q.lock.Lock()
 	defer q.lock.Unlock()
 
@@ -432,7 +432,7 @@ func (q *queue) ReserveHeaders(p *peer, count int) *fetchRequest {
 // ReserveBodies reserves a set of body fetches for the given peer, skipping any
 // previously failed downloads. Besides the next batch of needed fetches, it also
 // returns a flag indicating whether any empty blocks were queued, requiring processing.
-func (q *queue) ReserveBodies(p *peer, count int) (*fetchRequest, bool, error) {
+func (q *queue) ReserveBodies(p *peerConnection, count int) (*fetchRequest, bool, error) {
 	isNoop := func(header *types.Header) bool {
 		return header.TxHash == types.EmptyRootHash && header.UncleHash == types.EmptyUncleHash
 	}
@@ -445,7 +445,7 @@ func (q *queue) ReserveBodies(p *peer, count int) (*fetchRequest, bool, error) {
 // ReserveReceipts reserves a set of receipt fetches for the given peer, skipping
 // any previously failed downloads. Besides the next batch of needed fetches, it
 // also returns a flag indicating whether any empty receipts were queued, requiring importing.
-func (q *queue) ReserveReceipts(p *peer, count int) (*fetchRequest, bool, error) {
+func (q *queue) ReserveReceipts(p *peerConnection, count int) (*fetchRequest, bool, error) {
 	isNoop := func(header *types.Header) bool {
 		return header.ReceiptHash == types.EmptyRootHash
 	}
@@ -462,7 +462,7 @@ func (q *queue) ReserveReceipts(p *peer, count int) (*fetchRequest, bool, error)
 // Note, this method expects the queue lock to be already held for writing. The
 // reason the lock is not obtained in here is that the parameters already need
 // to access the queue, so they already need a lock anyway.
-func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque,
+func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque,
 	pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}, isNoop func(*types.Header) bool) (*fetchRequest, bool, error) {
 	// Short circuit if the pool has been depleted, or if the peer's already
 	// downloading something (sanity check not to corrupt state)

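Passing a *peerConnection into the reserve* helpers lets reserveHeaders consult the peer's lacking set directly while handing out tasks. A simplified sketch of that selection loop, written as if inside package downloader (timeout bookkeeping, requeueing of skipped items and the noop fast path are elided):

// selectHeaders pops pending work off the priority queue, skipping items
// the peer is already known to be missing, until the batch is full.
func selectHeaders(p *peerConnection, count int, taskQueue *prque.Prque, donePool map[common.Hash]struct{}) []*types.Header {
	send := make([]*types.Header, 0, count)
	for len(send) < count && !taskQueue.Empty() {
		header := taskQueue.PopItem().(*types.Header)
		if _, done := donePool[header.Hash()]; done {
			continue // already retrieved through another peer
		}
		if p.Lacks(header.Hash()) {
			continue // peer cannot serve it; requeue for others (elided)
		}
		send = append(send, header)
	}
	return send
}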
+ 2 - 2
eth/downloader/statesync.go

@@ -37,7 +37,7 @@ type stateReq struct {
 	tasks    map[common.Hash]*stateTask // Download tasks to track previous attempts
 	timeout  time.Duration              // Maximum round trip time for this to complete
 	timer    *time.Timer                // Timer to fire when the RTT timeout expires
-	peer     *peer                      // Peer that we're requesting from
+	peer     *peerConnection            // Peer that we're requesting from
 	response [][]byte                   // Response data of the peer (nil for timeouts)
 }
 
@@ -246,7 +246,7 @@ func (s *stateSync) Cancel() error {
 // and timeouts.
 func (s *stateSync) loop() error {
 	// Listen for new peer events to assign tasks to them
-	newPeer := make(chan *peer, 1024)
+	newPeer := make(chan *peerConnection, 1024)
 	peerSub := s.d.peers.SubscribeNewPeers(newPeer)
 	defer peerSub.Unsubscribe()
 

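SubscribeNewPeers is built on event.Feed, so the state sync loop reacts to peers that join mid-sync without polling the peer set. A self-contained illustration of the same feed/subscription pattern, using a plain string payload in place of *peerConnection:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/event"
)

func main() {
	var feed event.Feed

	// Subscribe with a buffered channel, just as the state sync loop does
	// with its newPeer channel.
	ch := make(chan string, 16)
	sub := feed.Subscribe(ch)
	defer sub.Unsubscribe()

	go feed.Send("peer-1") // a registration event from another goroutine

	select {
	case id := <-ch:
		fmt.Println("new peer registered:", id)
	case err := <-sub.Err():
		fmt.Println("subscription dropped:", err)
	}
}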
+ 0 - 41
eth/downloader/types.go

@@ -18,51 +18,10 @@ package downloader
 
 import (
 	"fmt"
-	"math/big"
 
-	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 )
 
-// headerCheckFn is a callback type for verifying a header's presence in the local chain.
-type headerCheckFn func(common.Hash) bool
-
-// blockAndStateCheckFn is a callback type for verifying block and associated states' presence in the local chain.
-type blockAndStateCheckFn func(common.Hash) bool
-
-// headerRetrievalFn is a callback type for retrieving a header from the local chain.
-type headerRetrievalFn func(common.Hash) *types.Header
-
-// blockRetrievalFn is a callback type for retrieving a block from the local chain.
-type blockRetrievalFn func(common.Hash) *types.Block
-
-// headHeaderRetrievalFn is a callback type for retrieving the head header from the local chain.
-type headHeaderRetrievalFn func() *types.Header
-
-// headBlockRetrievalFn is a callback type for retrieving the head block from the local chain.
-type headBlockRetrievalFn func() *types.Block
-
-// headFastBlockRetrievalFn is a callback type for retrieving the head fast block from the local chain.
-type headFastBlockRetrievalFn func() *types.Block
-
-// headBlockCommitterFn is a callback for directly committing the head block to a certain entity.
-type headBlockCommitterFn func(common.Hash) error
-
-// tdRetrievalFn is a callback type for retrieving the total difficulty of a local block.
-type tdRetrievalFn func(common.Hash) *big.Int
-
-// headerChainInsertFn is a callback type to insert a batch of headers into the local chain.
-type headerChainInsertFn func([]*types.Header, int) (int, error)
-
-// blockChainInsertFn is a callback type to insert a batch of blocks into the local chain.
-type blockChainInsertFn func(types.Blocks) (int, error)
-
-// receiptChainInsertFn is a callback type to insert a batch of receipts into the local chain.
-type receiptChainInsertFn func(types.Blocks, []types.Receipts) (int, error)
-
-// chainRollbackFn is a callback type to remove a few recently added elements from the local chain.
-type chainRollbackFn func([]common.Hash)
-
 // peerDropFn is a callback type for dropping a peer detected as malicious.
 type peerDropFn func(id string)
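Only peerDropFn survives here; every other hook is now reached through the chain objects injected into the Downloader. As an illustration of what that means for call sites (currentHeight is a hypothetical in-package helper, not part of the committed code), a lookup that used to go through a stored callback now reads:

// currentHeight shows the new call shape: chain state comes from the
// injected interfaces rather than stored callback fields.
func (d *Downloader) currentHeight() uint64 {
	if d.mode == LightSync {
		return d.lightchain.CurrentHeader().Number.Uint64()
	}
	return d.blockchain.CurrentBlock().NumberU64()
}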
 

+ 2 - 5
eth/handler.go

@@ -157,10 +157,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
 		return nil, errIncompatibleConfig
 	}
 	// Construct the different synchronisation mechanisms
-	manager.downloader = downloader.New(mode, chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlockAndState, blockchain.GetHeaderByHash,
-		blockchain.GetBlockByHash, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead,
-		blockchain.GetTdByHash, blockchain.InsertHeaderChain, manager.blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback,
-		manager.removePeer)
+	manager.downloader = downloader.New(mode, chaindb, manager.eventMux, blockchain, nil, manager.removePeer)
 
 	validator := func(header *types.Header) error {
 		return engine.VerifyHeader(blockchain, header, true)
@@ -268,7 +265,7 @@ func (pm *ProtocolManager) handle(p *peer) error {
 	defer pm.removePeer(p.id)
 
 	// Register the peer in the downloader. If the downloader considers it banned, we disconnect
-	if err := pm.downloader.RegisterPeer(p.id, p.version, p.Head, p.RequestHeadersByHash, p.RequestHeadersByNumber, p.RequestBodies, p.RequestReceipts, p.RequestNodeData); err != nil {
+	if err := pm.downloader.RegisterPeer(p.id, p.version, p); err != nil {
 		return err
 	}
 	// Propagate existing transactions. New transactions appearing

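RegisterPeer now receives the eth peer object itself instead of the six callbacks listed in the removed line above, which implies the downloader defines a peer-facing interface that *peer already satisfies. Judging from that old callback list, the contract is presumably close to the following sketch (an assumption; the authoritative definition lives in eth/downloader):

package sketch

import (
	"math/big"

	"github.com/ethereum/go-ethereum/common"
)

// Peer is the assumed downloader-side contract satisfied by eth's peer
// type after this change.
type Peer interface {
	Head() (common.Hash, *big.Int)
	RequestHeadersByHash(common.Hash, int, int, bool) error
	RequestHeadersByNumber(uint64, int, int, bool) error
	RequestBodies([]common.Hash) error
	RequestReceipts([]common.Hash) error
	RequestNodeData([]common.Hash) error
}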
+ 61 - 50
les/handler.go

@@ -206,9 +206,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network
 	}
 
 	if lightSync {
-		manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash,
-			nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash,
-			blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer)
+		manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, nil, blockchain, removePeer)
 		manager.peers.notify((*downloaderPeerNotify)(manager))
 		manager.fetcher = newLightFetcher(manager)
 	}
@@ -840,57 +838,70 @@ func (self *ProtocolManager) NodeInfo() *eth.EthNodeInfo {
 // downloaderPeerNotify implements peerSetNotify
 type downloaderPeerNotify ProtocolManager
 
-func (d *downloaderPeerNotify) registerPeer(p *peer) {
-	pm := (*ProtocolManager)(d)
+type peerConnection struct {
+	manager *ProtocolManager
+	peer    *peer
+}
 
-	requestHeadersByHash := func(origin common.Hash, amount int, skip int, reverse bool) error {
-		reqID := genReqID()
-		rq := &distReq{
-			getCost: func(dp distPeer) uint64 {
-				peer := dp.(*peer)
-				return peer.GetRequestCost(GetBlockHeadersMsg, amount)
-			},
-			canSend: func(dp distPeer) bool {
-				return dp.(*peer) == p
-			},
-			request: func(dp distPeer) func() {
-				peer := dp.(*peer)
-				cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
-				peer.fcServer.QueueRequest(reqID, cost)
-				return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) }
-			},
-		}
-		_, ok := <-pm.reqDist.queue(rq)
-		if !ok {
-			return ErrNoPeers
-		}
-		return nil
+func (pc *peerConnection) Head() (common.Hash, *big.Int) {
+	return pc.peer.HeadAndTd()
+}
+
+func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
+	reqID := genReqID()
+	rq := &distReq{
+		getCost: func(dp distPeer) uint64 {
+			peer := dp.(*peer)
+			return peer.GetRequestCost(GetBlockHeadersMsg, amount)
+		},
+		canSend: func(dp distPeer) bool {
+			return dp.(*peer) == pc.peer
+		},
+		request: func(dp distPeer) func() {
+			peer := dp.(*peer)
+			cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
+			peer.fcServer.QueueRequest(reqID, cost)
+			return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) }
+		},
 	}
-	requestHeadersByNumber := func(origin uint64, amount int, skip int, reverse bool) error {
-		reqID := genReqID()
-		rq := &distReq{
-			getCost: func(dp distPeer) uint64 {
-				peer := dp.(*peer)
-				return peer.GetRequestCost(GetBlockHeadersMsg, amount)
-			},
-			canSend: func(dp distPeer) bool {
-				return dp.(*peer) == p
-			},
-			request: func(dp distPeer) func() {
-				peer := dp.(*peer)
-				cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
-				peer.fcServer.QueueRequest(reqID, cost)
-				return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) }
-			},
-		}
-		_, ok := <-pm.reqDist.queue(rq)
-		if !ok {
-			return ErrNoPeers
-		}
-		return nil
+	_, ok := <-pc.manager.reqDist.queue(rq)
+	if !ok {
+		return ErrNoPeers
+	}
+	return nil
+}
+
+func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
+	reqID := genReqID()
+	rq := &distReq{
+		getCost: func(dp distPeer) uint64 {
+			peer := dp.(*peer)
+			return peer.GetRequestCost(GetBlockHeadersMsg, amount)
+		},
+		canSend: func(dp distPeer) bool {
+			return dp.(*peer) == pc.peer
+		},
+		request: func(dp distPeer) func() {
+			peer := dp.(*peer)
+			cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
+			peer.fcServer.QueueRequest(reqID, cost)
+			return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) }
+		},
 	}
+	_, ok := <-pc.manager.reqDist.queue(rq)
+	if !ok {
+		return ErrNoPeers
+	}
+	return nil
+}
 
-	pm.downloader.RegisterPeer(p.id, ethVersion, p.HeadAndTd, requestHeadersByHash, requestHeadersByNumber, nil, nil, nil)
+func (d *downloaderPeerNotify) registerPeer(p *peer) {
+	pm := (*ProtocolManager)(d)
+	pc := &peerConnection{
+		manager: pm,
+		peer:    p,
+	}
+	pm.downloader.RegisterLightPeer(p.id, ethVersion, pc)
 }
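Because a light server can only be asked for headers, the adapter above implements just the header-facing subset (Head plus the two header request methods) and registers through RegisterLightPeer rather than RegisterPeer. Presumably the downloader then wraps such a peer so the body, receipt and state request paths can never be exercised in LightSync; a hedged sketch of what that wrapper could look like, with LightPeer standing for the assumed three-method interface (names and behaviour are guesses, not the committed code):

// lightPeerShim is an illustrative wrapper promoting a headers-only peer
// to the full request interface; the extra methods must never be called.
type lightPeerShim struct {
	LightPeer // Head, RequestHeadersByHash, RequestHeadersByNumber
}

func (w lightPeerShim) RequestBodies([]common.Hash) error {
	panic("RequestBodies not supported on light peers")
}

func (w lightPeerShim) RequestReceipts([]common.Hash) error {
	panic("RequestReceipts not supported on light peers")
}

func (w lightPeerShim) RequestNodeData([]common.Hash) error {
	panic("RequestNodeData not supported on light peers")
}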
 
 func (d *downloaderPeerNotify) unregisterPeer(p *peer) {