浏览代码

les: handler separation (#19639)

les: handler separation
gary rong 6 年之前
父节点
当前提交
2ed729d38e
共有 31 个文件被更改,包括 2364 次插入2512 次删除
  1. 22 0
      core/blockchain.go
  2. 5 9
      les/api.go
  3. 1 1
      les/api_backend.go
  4. 5 13
      les/api_test.go
  5. 24 23
      les/benchmark.go
  6. 2 1
      les/bloombits.go
  7. 58 73
      les/client.go
  8. 401 0
      les/client_handler.go
  9. 47 22
      les/commons.go
  10. 8 3
      les/costtracker.go
  11. 26 11
      les/distributor.go
  12. 1 1
      les/distributor_test.go
  13. 37 38
      les/fetcher.go
  14. 0 168
      les/fetcher_test.go
  15. 0 1293
      les/handler.go
  16. 108 90
      les/handler_test.go
  17. 63 27
      les/metrics.go
  18. 4 1
      les/odr.go
  19. 22 16
      les/odr_test.go
  20. 25 23
      les/peer.go
  21. 38 42
      les/peer_test.go
  22. 10 18
      les/request_test.go
  23. 143 197
      les/server.go
  24. 921 0
      les/server_handler.go
  25. 36 25
      les/serverpool.go
  26. 21 50
      les/sync.go
  27. 8 9
      les/sync_test.go
  28. 247 201
      les/test_helper.go
  29. 75 151
      les/ulc_test.go
  30. 3 3
      light/odr_util.go
  31. 3 3
      light/postprocess.go

+ 22 - 0
core/blockchain.go

@@ -75,6 +75,7 @@ const (
 	bodyCacheLimit      = 256
 	bodyCacheLimit      = 256
 	blockCacheLimit     = 256
 	blockCacheLimit     = 256
 	receiptsCacheLimit  = 32
 	receiptsCacheLimit  = 32
+	txLookupCacheLimit  = 1024
 	maxFutureBlocks     = 256
 	maxFutureBlocks     = 256
 	maxTimeFutureBlocks = 30
 	maxTimeFutureBlocks = 30
 	badBlockLimit       = 10
 	badBlockLimit       = 10
@@ -155,6 +156,7 @@ type BlockChain struct {
 	bodyRLPCache  *lru.Cache     // Cache for the most recent block bodies in RLP encoded format
 	bodyRLPCache  *lru.Cache     // Cache for the most recent block bodies in RLP encoded format
 	receiptsCache *lru.Cache     // Cache for the most recent receipts per block
 	receiptsCache *lru.Cache     // Cache for the most recent receipts per block
 	blockCache    *lru.Cache     // Cache for the most recent entire blocks
 	blockCache    *lru.Cache     // Cache for the most recent entire blocks
+	txLookupCache *lru.Cache     // Cache for the most recent transaction lookup data.
 	futureBlocks  *lru.Cache     // future blocks are blocks added for later processing
 	futureBlocks  *lru.Cache     // future blocks are blocks added for later processing
 
 
 	quit    chan struct{} // blockchain quit channel
 	quit    chan struct{} // blockchain quit channel
@@ -189,6 +191,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 	bodyRLPCache, _ := lru.New(bodyCacheLimit)
 	bodyRLPCache, _ := lru.New(bodyCacheLimit)
 	receiptsCache, _ := lru.New(receiptsCacheLimit)
 	receiptsCache, _ := lru.New(receiptsCacheLimit)
 	blockCache, _ := lru.New(blockCacheLimit)
 	blockCache, _ := lru.New(blockCacheLimit)
+	txLookupCache, _ := lru.New(txLookupCacheLimit)
 	futureBlocks, _ := lru.New(maxFutureBlocks)
 	futureBlocks, _ := lru.New(maxFutureBlocks)
 	badBlocks, _ := lru.New(badBlockLimit)
 	badBlocks, _ := lru.New(badBlockLimit)
 
 
@@ -204,6 +207,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 		bodyRLPCache:   bodyRLPCache,
 		bodyRLPCache:   bodyRLPCache,
 		receiptsCache:  receiptsCache,
 		receiptsCache:  receiptsCache,
 		blockCache:     blockCache,
 		blockCache:     blockCache,
+		txLookupCache:  txLookupCache,
 		futureBlocks:   futureBlocks,
 		futureBlocks:   futureBlocks,
 		engine:         engine,
 		engine:         engine,
 		vmConfig:       vmConfig,
 		vmConfig:       vmConfig,
@@ -440,6 +444,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
 	bc.bodyRLPCache.Purge()
 	bc.bodyRLPCache.Purge()
 	bc.receiptsCache.Purge()
 	bc.receiptsCache.Purge()
 	bc.blockCache.Purge()
 	bc.blockCache.Purge()
+	bc.txLookupCache.Purge()
 	bc.futureBlocks.Purge()
 	bc.futureBlocks.Purge()
 
 
 	return bc.loadLastState()
 	return bc.loadLastState()
@@ -921,6 +926,7 @@ func (bc *BlockChain) truncateAncient(head uint64) error {
 	bc.bodyRLPCache.Purge()
 	bc.bodyRLPCache.Purge()
 	bc.receiptsCache.Purge()
 	bc.receiptsCache.Purge()
 	bc.blockCache.Purge()
 	bc.blockCache.Purge()
+	bc.txLookupCache.Purge()
 	bc.futureBlocks.Purge()
 	bc.futureBlocks.Purge()
 
 
 	log.Info("Rewind ancient data", "number", head)
 	log.Info("Rewind ancient data", "number", head)
@@ -2151,6 +2157,22 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
 	return bc.hc.GetHeaderByNumber(number)
 	return bc.hc.GetHeaderByNumber(number)
 }
 }
 
 
+// GetTransactionLookup retrieves the lookup associate with the given transaction
+// hash from the cache or database.
+func (bc *BlockChain) GetTransactionLookup(hash common.Hash) *rawdb.LegacyTxLookupEntry {
+	// Short circuit if the txlookup already in the cache, retrieve otherwise
+	if lookup, exist := bc.txLookupCache.Get(hash); exist {
+		return lookup.(*rawdb.LegacyTxLookupEntry)
+	}
+	tx, blockHash, blockNumber, txIndex := rawdb.ReadTransaction(bc.db, hash)
+	if tx == nil {
+		return nil
+	}
+	lookup := &rawdb.LegacyTxLookupEntry{BlockHash: blockHash, BlockIndex: blockNumber, Index: txIndex}
+	bc.txLookupCache.Add(hash, lookup)
+	return lookup
+}
+
 // Config retrieves the chain's fork configuration.
 // Config retrieves the chain's fork configuration.
 func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig }
 func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig }
 
 

+ 5 - 9
les/api.go

@@ -30,15 +30,11 @@ var (
 // PrivateLightAPI provides an API to access the LES light server or light client.
 // PrivateLightAPI provides an API to access the LES light server or light client.
 type PrivateLightAPI struct {
 type PrivateLightAPI struct {
 	backend *lesCommons
 	backend *lesCommons
-	reg     *checkpointOracle
 }
 }
 
 
 // NewPrivateLightAPI creates a new LES service API.
 // NewPrivateLightAPI creates a new LES service API.
-func NewPrivateLightAPI(backend *lesCommons, reg *checkpointOracle) *PrivateLightAPI {
-	return &PrivateLightAPI{
-		backend: backend,
-		reg:     reg,
-	}
+func NewPrivateLightAPI(backend *lesCommons) *PrivateLightAPI {
+	return &PrivateLightAPI{backend: backend}
 }
 }
 
 
 // LatestCheckpoint returns the latest local checkpoint package.
 // LatestCheckpoint returns the latest local checkpoint package.
@@ -67,7 +63,7 @@ func (api *PrivateLightAPI) LatestCheckpoint() ([4]string, error) {
 //   result[2], 32 bytes hex encoded latest section bloom trie root hash
 //   result[2], 32 bytes hex encoded latest section bloom trie root hash
 func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) {
 func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) {
 	var res [3]string
 	var res [3]string
-	cp := api.backend.getLocalCheckpoint(index)
+	cp := api.backend.localCheckpoint(index)
 	if cp.Empty() {
 	if cp.Empty() {
 		return res, errNoCheckpoint
 		return res, errNoCheckpoint
 	}
 	}
@@ -77,8 +73,8 @@ func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) {
 
 
 // GetCheckpointContractAddress returns the contract contract address in hex format.
 // GetCheckpointContractAddress returns the contract contract address in hex format.
 func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) {
 func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) {
-	if api.reg == nil {
+	if api.backend.oracle == nil {
 		return "", errNotActivated
 		return "", errNotActivated
 	}
 	}
-	return api.reg.config.Address.Hex(), nil
+	return api.backend.oracle.config.Address.Hex(), nil
 }
 }

+ 1 - 1
les/api_backend.go

@@ -54,7 +54,7 @@ func (b *LesApiBackend) CurrentBlock() *types.Block {
 }
 }
 
 
 func (b *LesApiBackend) SetHead(number uint64) {
 func (b *LesApiBackend) SetHead(number uint64) {
-	b.eth.protocolManager.downloader.Cancel()
+	b.eth.handler.downloader.Cancel()
 	b.eth.blockchain.SetHead(number)
 	b.eth.blockchain.SetHead(number)
 }
 }
 
 

+ 5 - 13
les/api_test.go

@@ -78,19 +78,16 @@ func TestCapacityAPI10(t *testing.T) {
 // while connected and going back and forth between free and priority mode with
 // while connected and going back and forth between free and priority mode with
 // the supplied API calls is also thoroughly tested.
 // the supplied API calls is also thoroughly tested.
 func testCapacityAPI(t *testing.T, clientCount int) {
 func testCapacityAPI(t *testing.T, clientCount int) {
+	// Skip test if no data dir specified
 	if testServerDataDir == "" {
 	if testServerDataDir == "" {
-		// Skip test if no data dir specified
 		return
 		return
 	}
 	}
-
 	for !testSim(t, 1, clientCount, []string{testServerDataDir}, nil, func(ctx context.Context, net *simulations.Network, servers []*simulations.Node, clients []*simulations.Node) bool {
 	for !testSim(t, 1, clientCount, []string{testServerDataDir}, nil, func(ctx context.Context, net *simulations.Network, servers []*simulations.Node, clients []*simulations.Node) bool {
 		if len(servers) != 1 {
 		if len(servers) != 1 {
 			t.Fatalf("Invalid number of servers: %d", len(servers))
 			t.Fatalf("Invalid number of servers: %d", len(servers))
 		}
 		}
 		server := servers[0]
 		server := servers[0]
 
 
-		clientRpcClients := make([]*rpc.Client, len(clients))
-
 		serverRpcClient, err := server.Client()
 		serverRpcClient, err := server.Client()
 		if err != nil {
 		if err != nil {
 			t.Fatalf("Failed to obtain rpc client: %v", err)
 			t.Fatalf("Failed to obtain rpc client: %v", err)
@@ -105,13 +102,13 @@ func testCapacityAPI(t *testing.T, clientCount int) {
 		}
 		}
 		freeIdx := rand.Intn(len(clients))
 		freeIdx := rand.Intn(len(clients))
 
 
+		clientRpcClients := make([]*rpc.Client, len(clients))
 		for i, client := range clients {
 		for i, client := range clients {
 			var err error
 			var err error
 			clientRpcClients[i], err = client.Client()
 			clientRpcClients[i], err = client.Client()
 			if err != nil {
 			if err != nil {
 				t.Fatalf("Failed to obtain rpc client: %v", err)
 				t.Fatalf("Failed to obtain rpc client: %v", err)
 			}
 			}
-
 			t.Log("connecting client", i)
 			t.Log("connecting client", i)
 			if i != freeIdx {
 			if i != freeIdx {
 				setCapacity(ctx, t, serverRpcClient, client.ID(), testCap/uint64(len(clients)))
 				setCapacity(ctx, t, serverRpcClient, client.ID(), testCap/uint64(len(clients)))
@@ -138,10 +135,13 @@ func testCapacityAPI(t *testing.T, clientCount int) {
 
 
 		reqCount := make([]uint64, len(clientRpcClients))
 		reqCount := make([]uint64, len(clientRpcClients))
 
 
+		// Send light request like crazy.
 		for i, c := range clientRpcClients {
 		for i, c := range clientRpcClients {
 			wg.Add(1)
 			wg.Add(1)
 			i, c := i, c
 			i, c := i, c
 			go func() {
 			go func() {
+				defer wg.Done()
+
 				queue := make(chan struct{}, 100)
 				queue := make(chan struct{}, 100)
 				reqCount[i] = 0
 				reqCount[i] = 0
 				for {
 				for {
@@ -149,10 +149,8 @@ func testCapacityAPI(t *testing.T, clientCount int) {
 					case queue <- struct{}{}:
 					case queue <- struct{}{}:
 						select {
 						select {
 						case <-stop:
 						case <-stop:
-							wg.Done()
 							return
 							return
 						case <-ctx.Done():
 						case <-ctx.Done():
-							wg.Done()
 							return
 							return
 						default:
 						default:
 							wg.Add(1)
 							wg.Add(1)
@@ -169,10 +167,8 @@ func testCapacityAPI(t *testing.T, clientCount int) {
 							}()
 							}()
 						}
 						}
 					case <-stop:
 					case <-stop:
-						wg.Done()
 						return
 						return
 					case <-ctx.Done():
 					case <-ctx.Done():
-						wg.Done()
 						return
 						return
 					}
 					}
 				}
 				}
@@ -313,12 +309,10 @@ func getHead(ctx context.Context, t *testing.T, client *rpc.Client) (uint64, com
 }
 }
 
 
 func testRequest(ctx context.Context, t *testing.T, client *rpc.Client) bool {
 func testRequest(ctx context.Context, t *testing.T, client *rpc.Client) bool {
-	//res := make(map[string]interface{})
 	var res string
 	var res string
 	var addr common.Address
 	var addr common.Address
 	rand.Read(addr[:])
 	rand.Read(addr[:])
 	c, _ := context.WithTimeout(ctx, time.Second*12)
 	c, _ := context.WithTimeout(ctx, time.Second*12)
-	//	if err := client.CallContext(ctx, &res, "eth_getProof", addr, nil, "latest"); err != nil {
 	err := client.CallContext(c, &res, "eth_getBalance", addr, "latest")
 	err := client.CallContext(c, &res, "eth_getBalance", addr, "latest")
 	if err != nil {
 	if err != nil {
 		t.Log("request error:", err)
 		t.Log("request error:", err)
@@ -418,7 +412,6 @@ func NewNetwork() (*simulations.Network, func(), error) {
 		adapterTeardown()
 		adapterTeardown()
 		net.Shutdown()
 		net.Shutdown()
 	}
 	}
-
 	return net, teardown, nil
 	return net, teardown, nil
 }
 }
 
 
@@ -516,7 +509,6 @@ func newLesServerService(ctx *adapters.ServiceContext) (node.Service, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-
 	server, err := NewLesServer(ethereum, &config)
 	server, err := NewLesServer(ethereum, &config)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 24 - 23
les/benchmark.go

@@ -39,7 +39,7 @@ import (
 // requestBenchmark is an interface for different randomized request generators
 // requestBenchmark is an interface for different randomized request generators
 type requestBenchmark interface {
 type requestBenchmark interface {
 	// init initializes the generator for generating the given number of randomized requests
 	// init initializes the generator for generating the given number of randomized requests
-	init(pm *ProtocolManager, count int) error
+	init(h *serverHandler, count int) error
 	// request initiates sending a single request to the given peer
 	// request initiates sending a single request to the given peer
 	request(peer *peer, index int) error
 	request(peer *peer, index int) error
 }
 }
@@ -52,10 +52,10 @@ type benchmarkBlockHeaders struct {
 	hashes          []common.Hash
 	hashes          []common.Hash
 }
 }
 
 
-func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error {
+func (b *benchmarkBlockHeaders) init(h *serverHandler, count int) error {
 	d := int64(b.amount-1) * int64(b.skip+1)
 	d := int64(b.amount-1) * int64(b.skip+1)
 	b.offset = 0
 	b.offset = 0
-	b.randMax = pm.blockchain.CurrentHeader().Number.Int64() + 1 - d
+	b.randMax = h.blockchain.CurrentHeader().Number.Int64() + 1 - d
 	if b.randMax < 0 {
 	if b.randMax < 0 {
 		return fmt.Errorf("chain is too short")
 		return fmt.Errorf("chain is too short")
 	}
 	}
@@ -65,7 +65,7 @@ func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error {
 	if b.byHash {
 	if b.byHash {
 		b.hashes = make([]common.Hash, count)
 		b.hashes = make([]common.Hash, count)
 		for i := range b.hashes {
 		for i := range b.hashes {
-			b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(b.offset+rand.Int63n(b.randMax)))
+			b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(b.offset+rand.Int63n(b.randMax)))
 		}
 		}
 	}
 	}
 	return nil
 	return nil
@@ -85,11 +85,11 @@ type benchmarkBodiesOrReceipts struct {
 	hashes   []common.Hash
 	hashes   []common.Hash
 }
 }
 
 
-func (b *benchmarkBodiesOrReceipts) init(pm *ProtocolManager, count int) error {
-	randMax := pm.blockchain.CurrentHeader().Number.Int64() + 1
+func (b *benchmarkBodiesOrReceipts) init(h *serverHandler, count int) error {
+	randMax := h.blockchain.CurrentHeader().Number.Int64() + 1
 	b.hashes = make([]common.Hash, count)
 	b.hashes = make([]common.Hash, count)
 	for i := range b.hashes {
 	for i := range b.hashes {
-		b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(rand.Int63n(randMax)))
+		b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(rand.Int63n(randMax)))
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -108,8 +108,8 @@ type benchmarkProofsOrCode struct {
 	headHash common.Hash
 	headHash common.Hash
 }
 }
 
 
-func (b *benchmarkProofsOrCode) init(pm *ProtocolManager, count int) error {
-	b.headHash = pm.blockchain.CurrentHeader().Hash()
+func (b *benchmarkProofsOrCode) init(h *serverHandler, count int) error {
+	b.headHash = h.blockchain.CurrentHeader().Hash()
 	return nil
 	return nil
 }
 }
 
 
@@ -130,11 +130,11 @@ type benchmarkHelperTrie struct {
 	sectionCount, headNum uint64
 	sectionCount, headNum uint64
 }
 }
 
 
-func (b *benchmarkHelperTrie) init(pm *ProtocolManager, count int) error {
+func (b *benchmarkHelperTrie) init(h *serverHandler, count int) error {
 	if b.bloom {
 	if b.bloom {
-		b.sectionCount, b.headNum, _ = pm.server.bloomTrieIndexer.Sections()
+		b.sectionCount, b.headNum, _ = h.server.bloomTrieIndexer.Sections()
 	} else {
 	} else {
-		b.sectionCount, _, _ = pm.server.chtIndexer.Sections()
+		b.sectionCount, _, _ = h.server.chtIndexer.Sections()
 		b.headNum = b.sectionCount*params.CHTFrequency - 1
 		b.headNum = b.sectionCount*params.CHTFrequency - 1
 	}
 	}
 	if b.sectionCount == 0 {
 	if b.sectionCount == 0 {
@@ -170,7 +170,7 @@ type benchmarkTxSend struct {
 	txs types.Transactions
 	txs types.Transactions
 }
 }
 
 
-func (b *benchmarkTxSend) init(pm *ProtocolManager, count int) error {
+func (b *benchmarkTxSend) init(h *serverHandler, count int) error {
 	key, _ := crypto.GenerateKey()
 	key, _ := crypto.GenerateKey()
 	addr := crypto.PubkeyToAddress(key.PublicKey)
 	addr := crypto.PubkeyToAddress(key.PublicKey)
 	signer := types.NewEIP155Signer(big.NewInt(18))
 	signer := types.NewEIP155Signer(big.NewInt(18))
@@ -196,7 +196,7 @@ func (b *benchmarkTxSend) request(peer *peer, index int) error {
 // benchmarkTxStatus implements requestBenchmark
 // benchmarkTxStatus implements requestBenchmark
 type benchmarkTxStatus struct{}
 type benchmarkTxStatus struct{}
 
 
-func (b *benchmarkTxStatus) init(pm *ProtocolManager, count int) error {
+func (b *benchmarkTxStatus) init(h *serverHandler, count int) error {
 	return nil
 	return nil
 }
 }
 
 
@@ -217,7 +217,7 @@ type benchmarkSetup struct {
 
 
 // runBenchmark runs a benchmark cycle for all benchmark types in the specified
 // runBenchmark runs a benchmark cycle for all benchmark types in the specified
 // number of passes
 // number of passes
-func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup {
+func (h *serverHandler) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup {
 	setup := make([]*benchmarkSetup, len(benchmarks))
 	setup := make([]*benchmarkSetup, len(benchmarks))
 	for i, b := range benchmarks {
 	for i, b := range benchmarks {
 		setup[i] = &benchmarkSetup{req: b}
 		setup[i] = &benchmarkSetup{req: b}
@@ -239,7 +239,7 @@ func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount
 				if next.totalTime > 0 {
 				if next.totalTime > 0 {
 					count = int(uint64(next.totalCount) * uint64(targetTime) / uint64(next.totalTime))
 					count = int(uint64(next.totalCount) * uint64(targetTime) / uint64(next.totalTime))
 				}
 				}
-				if err := pm.measure(next, count); err != nil {
+				if err := h.measure(next, count); err != nil {
 					next.err = err
 					next.err = err
 				}
 				}
 			}
 			}
@@ -275,14 +275,15 @@ func (m *meteredPipe) WriteMsg(msg p2p.Msg) error {
 
 
 // measure runs a benchmark for a single type in a single pass, with the given
 // measure runs a benchmark for a single type in a single pass, with the given
 // number of requests
 // number of requests
-func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
+func (h *serverHandler) measure(setup *benchmarkSetup, count int) error {
 	clientPipe, serverPipe := p2p.MsgPipe()
 	clientPipe, serverPipe := p2p.MsgPipe()
 	clientMeteredPipe := &meteredPipe{rw: clientPipe}
 	clientMeteredPipe := &meteredPipe{rw: clientPipe}
 	serverMeteredPipe := &meteredPipe{rw: serverPipe}
 	serverMeteredPipe := &meteredPipe{rw: serverPipe}
 	var id enode.ID
 	var id enode.ID
 	rand.Read(id[:])
 	rand.Read(id[:])
-	clientPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "client", nil), clientMeteredPipe)
-	serverPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "server", nil), serverMeteredPipe)
+
+	clientPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "client", nil), clientMeteredPipe)
+	serverPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "server", nil), serverMeteredPipe)
 	serverPeer.sendQueue = newExecQueue(count)
 	serverPeer.sendQueue = newExecQueue(count)
 	serverPeer.announceType = announceTypeNone
 	serverPeer.announceType = announceTypeNone
 	serverPeer.fcCosts = make(requestCostTable)
 	serverPeer.fcCosts = make(requestCostTable)
@@ -291,10 +292,10 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
 		serverPeer.fcCosts[code] = c
 		serverPeer.fcCosts[code] = c
 	}
 	}
 	serverPeer.fcParams = flowcontrol.ServerParams{BufLimit: 1, MinRecharge: 1}
 	serverPeer.fcParams = flowcontrol.ServerParams{BufLimit: 1, MinRecharge: 1}
-	serverPeer.fcClient = flowcontrol.NewClientNode(pm.server.fcManager, serverPeer.fcParams)
+	serverPeer.fcClient = flowcontrol.NewClientNode(h.server.fcManager, serverPeer.fcParams)
 	defer serverPeer.fcClient.Disconnect()
 	defer serverPeer.fcClient.Disconnect()
 
 
-	if err := setup.req.init(pm, count); err != nil {
+	if err := setup.req.init(h, count); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -311,7 +312,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
 	}()
 	}()
 	go func() {
 	go func() {
 		for i := 0; i < count; i++ {
 		for i := 0; i < count; i++ {
-			if err := pm.handleMsg(serverPeer); err != nil {
+			if err := h.handleMsg(serverPeer); err != nil {
 				errCh <- err
 				errCh <- err
 				return
 				return
 			}
 			}
@@ -336,7 +337,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-	case <-pm.quitSync:
+	case <-h.closeCh:
 		clientPipe.Close()
 		clientPipe.Close()
 		serverPipe.Close()
 		serverPipe.Close()
 		return fmt.Errorf("Benchmark cancelled")
 		return fmt.Errorf("Benchmark cancelled")

+ 2 - 1
les/bloombits.go

@@ -46,9 +46,10 @@ const (
 func (eth *LightEthereum) startBloomHandlers(sectionSize uint64) {
 func (eth *LightEthereum) startBloomHandlers(sectionSize uint64) {
 	for i := 0; i < bloomServiceThreads; i++ {
 	for i := 0; i < bloomServiceThreads; i++ {
 		go func() {
 		go func() {
+			defer eth.wg.Done()
 			for {
 			for {
 				select {
 				select {
-				case <-eth.shutdownChan:
+				case <-eth.closeCh:
 					return
 					return
 
 
 				case request := <-eth.bloomRequests:
 				case request := <-eth.bloomRequests:

+ 58 - 73
les/backend.go → les/client.go

@@ -19,8 +19,6 @@ package les
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"sync"
-	"time"
 
 
 	"github.com/ethereum/go-ethereum/accounts"
 	"github.com/ethereum/go-ethereum/accounts"
 	"github.com/ethereum/go-ethereum/accounts/abi/bind"
 	"github.com/ethereum/go-ethereum/accounts/abi/bind"
@@ -42,7 +40,7 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/node"
 	"github.com/ethereum/go-ethereum/node"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p"
-	"github.com/ethereum/go-ethereum/p2p/discv5"
+	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/ethereum/go-ethereum/rpc"
 	"github.com/ethereum/go-ethereum/rpc"
 )
 )
@@ -50,33 +48,23 @@ import (
 type LightEthereum struct {
 type LightEthereum struct {
 	lesCommons
 	lesCommons
 
 
-	odr         *LesOdr
-	chainConfig *params.ChainConfig
-	// Channel for shutting down the service
-	shutdownChan chan bool
-
-	// Handlers
-	peers      *peerSet
-	txPool     *light.TxPool
-	blockchain *light.LightChain
-	serverPool *serverPool
 	reqDist    *requestDistributor
 	reqDist    *requestDistributor
 	retriever  *retrieveManager
 	retriever  *retrieveManager
+	odr        *LesOdr
 	relay      *lesTxRelay
 	relay      *lesTxRelay
+	handler    *clientHandler
+	txPool     *light.TxPool
+	blockchain *light.LightChain
+	serverPool *serverPool
 
 
 	bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests
 	bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests
-	bloomIndexer  *core.ChainIndexer
-
-	ApiBackend *LesApiBackend
+	bloomIndexer  *core.ChainIndexer             // Bloom indexer operating during block imports
 
 
+	ApiBackend     *LesApiBackend
 	eventMux       *event.TypeMux
 	eventMux       *event.TypeMux
 	engine         consensus.Engine
 	engine         consensus.Engine
 	accountManager *accounts.Manager
 	accountManager *accounts.Manager
-
-	networkId     uint64
-	netRPCService *ethapi.PublicNetAPI
-
-	wg sync.WaitGroup
+	netRPCService  *ethapi.PublicNetAPI
 }
 }
 
 
 func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
 func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
@@ -91,26 +79,24 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
 	log.Info("Initialised chain configuration", "config", chainConfig)
 	log.Info("Initialised chain configuration", "config", chainConfig)
 
 
 	peers := newPeerSet()
 	peers := newPeerSet()
-	quitSync := make(chan struct{})
-
 	leth := &LightEthereum{
 	leth := &LightEthereum{
 		lesCommons: lesCommons{
 		lesCommons: lesCommons{
-			chainDb: chainDb,
-			config:  config,
-			iConfig: light.DefaultClientIndexerConfig,
+			genesis:     genesisHash,
+			config:      config,
+			chainConfig: chainConfig,
+			iConfig:     light.DefaultClientIndexerConfig,
+			chainDb:     chainDb,
+			peers:       peers,
+			closeCh:     make(chan struct{}),
 		},
 		},
-		chainConfig:    chainConfig,
 		eventMux:       ctx.EventMux,
 		eventMux:       ctx.EventMux,
-		peers:          peers,
-		reqDist:        newRequestDistributor(peers, quitSync, &mclock.System{}),
+		reqDist:        newRequestDistributor(peers, &mclock.System{}),
 		accountManager: ctx.AccountManager,
 		accountManager: ctx.AccountManager,
 		engine:         eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb),
 		engine:         eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb),
-		shutdownChan:   make(chan bool),
-		networkId:      config.NetworkId,
 		bloomRequests:  make(chan chan *bloombits.Retrieval),
 		bloomRequests:  make(chan chan *bloombits.Retrieval),
 		bloomIndexer:   eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
 		bloomIndexer:   eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
+		serverPool:     newServerPool(chainDb, config.UltraLightServers),
 	}
 	}
-	leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg, leth.config.UltraLightServers)
 	leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool)
 	leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool)
 	leth.relay = newLesTxRelay(peers, leth.retriever)
 	leth.relay = newLesTxRelay(peers, leth.retriever)
 
 
@@ -128,11 +114,26 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
 	if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil {
 	if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	leth.chainReader = leth.blockchain
+	leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
+
+	// Set up checkpoint oracle.
+	oracle := config.CheckpointOracle
+	if oracle == nil {
+		oracle = params.CheckpointOracles[genesisHash]
+	}
+	leth.oracle = newCheckpointOracle(oracle, leth.localCheckpoint)
+
 	// Note: AddChildIndexer starts the update process for the child
 	// Note: AddChildIndexer starts the update process for the child
 	leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer)
 	leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer)
 	leth.chtIndexer.Start(leth.blockchain)
 	leth.chtIndexer.Start(leth.blockchain)
 	leth.bloomIndexer.Start(leth.blockchain)
 	leth.bloomIndexer.Start(leth.blockchain)
 
 
+	leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
+	if leth.handler.ulc != nil {
+		log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
+		leth.blockchain.DisableCheckFreq()
+	}
 	// Rewind the chain in case of an incompatible config upgrade.
 	// Rewind the chain in case of an incompatible config upgrade.
 	if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
 	if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
 		log.Warn("Rewinding chain to upgrade configuration", "err", compat)
 		log.Warn("Rewinding chain to upgrade configuration", "err", compat)
@@ -140,41 +141,16 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
 		rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig)
 		rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig)
 	}
 	}
 
 
-	leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
 	leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil}
 	leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil}
-
 	gpoParams := config.GPO
 	gpoParams := config.GPO
 	if gpoParams.Default == nil {
 	if gpoParams.Default == nil {
 		gpoParams.Default = config.Miner.GasPrice
 		gpoParams.Default = config.Miner.GasPrice
 	}
 	}
 	leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
 	leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
 
 
-	oracle := config.CheckpointOracle
-	if oracle == nil {
-		oracle = params.CheckpointOracles[genesisHash]
-	}
-	registrar := newCheckpointOracle(oracle, leth.getLocalCheckpoint)
-	if leth.protocolManager, err = NewProtocolManager(leth.chainConfig, checkpoint, light.DefaultClientIndexerConfig, config.UltraLightServers, config.UltraLightFraction, true, config.NetworkId, leth.eventMux, leth.peers, leth.blockchain, nil, chainDb, leth.odr, leth.serverPool, registrar, quitSync, &leth.wg, nil); err != nil {
-		return nil, err
-	}
-	if leth.protocolManager.ulc != nil {
-		log.Warn("Ultra light client is enabled", "servers", len(config.UltraLightServers), "fraction", config.UltraLightFraction)
-		leth.blockchain.DisableCheckFreq()
-	}
 	return leth, nil
 	return leth, nil
 }
 }
 
 
-func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic {
-	var name string
-	switch protocolVersion {
-	case lpv2:
-		name = "LES2"
-	default:
-		panic(nil)
-	}
-	return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8]))
-}
-
 type LightDummyAPI struct{}
 type LightDummyAPI struct{}
 
 
 // Etherbase is the address that mining rewards will be send to
 // Etherbase is the address that mining rewards will be send to
@@ -209,7 +185,7 @@ func (s *LightEthereum) APIs() []rpc.API {
 		}, {
 		}, {
 			Namespace: "eth",
 			Namespace: "eth",
 			Version:   "1.0",
 			Version:   "1.0",
-			Service:   downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux),
+			Service:   downloader.NewPublicDownloaderAPI(s.handler.downloader, s.eventMux),
 			Public:    true,
 			Public:    true,
 		}, {
 		}, {
 			Namespace: "eth",
 			Namespace: "eth",
@@ -224,7 +200,7 @@ func (s *LightEthereum) APIs() []rpc.API {
 		}, {
 		}, {
 			Namespace: "les",
 			Namespace: "les",
 			Version:   "1.0",
 			Version:   "1.0",
-			Service:   NewPrivateLightAPI(&s.lesCommons, s.protocolManager.reg),
+			Service:   NewPrivateLightAPI(&s.lesCommons),
 			Public:    false,
 			Public:    false,
 		},
 		},
 	}...)
 	}...)
@@ -238,54 +214,63 @@ func (s *LightEthereum) BlockChain() *light.LightChain      { return s.blockchai
 func (s *LightEthereum) TxPool() *light.TxPool              { return s.txPool }
 func (s *LightEthereum) TxPool() *light.TxPool              { return s.txPool }
 func (s *LightEthereum) Engine() consensus.Engine           { return s.engine }
 func (s *LightEthereum) Engine() consensus.Engine           { return s.engine }
 func (s *LightEthereum) LesVersion() int                    { return int(ClientProtocolVersions[0]) }
 func (s *LightEthereum) LesVersion() int                    { return int(ClientProtocolVersions[0]) }
-func (s *LightEthereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader }
+func (s *LightEthereum) Downloader() *downloader.Downloader { return s.handler.downloader }
 func (s *LightEthereum) EventMux() *event.TypeMux           { return s.eventMux }
 func (s *LightEthereum) EventMux() *event.TypeMux           { return s.eventMux }
 
 
 // Protocols implements node.Service, returning all the currently configured
 // Protocols implements node.Service, returning all the currently configured
 // network protocols to start.
 // network protocols to start.
 func (s *LightEthereum) Protocols() []p2p.Protocol {
 func (s *LightEthereum) Protocols() []p2p.Protocol {
-	return s.makeProtocols(ClientProtocolVersions)
+	return s.makeProtocols(ClientProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
+		if p := s.peers.Peer(peerIdToString(id)); p != nil {
+			return p.Info()
+		}
+		return nil
+	})
 }
 }
 
 
 // Start implements node.Service, starting all internal goroutines needed by the
 // Start implements node.Service, starting all internal goroutines needed by the
-// Ethereum protocol implementation.
+// light ethereum protocol implementation.
 func (s *LightEthereum) Start(srvr *p2p.Server) error {
 func (s *LightEthereum) Start(srvr *p2p.Server) error {
 	log.Warn("Light client mode is an experimental feature")
 	log.Warn("Light client mode is an experimental feature")
+
+	// Start bloom request workers.
+	s.wg.Add(bloomServiceThreads)
 	s.startBloomHandlers(params.BloomBitsBlocksClient)
 	s.startBloomHandlers(params.BloomBitsBlocksClient)
-	s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId)
+
+	s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId)
+
 	// clients are searching for the first advertised protocol in the list
 	// clients are searching for the first advertised protocol in the list
 	protocolVersion := AdvertiseProtocolVersions[0]
 	protocolVersion := AdvertiseProtocolVersions[0]
 	s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion))
 	s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion))
-	s.protocolManager.Start(s.config.LightPeers)
 	return nil
 	return nil
 }
 }
 
 
 // Stop implements node.Service, terminating all internal goroutines used by the
 // Stop implements node.Service, terminating all internal goroutines used by the
 // Ethereum protocol.
 // Ethereum protocol.
 func (s *LightEthereum) Stop() error {
 func (s *LightEthereum) Stop() error {
+	close(s.closeCh)
+	s.peers.Close()
+	s.reqDist.close()
 	s.odr.Stop()
 	s.odr.Stop()
 	s.relay.Stop()
 	s.relay.Stop()
 	s.bloomIndexer.Close()
 	s.bloomIndexer.Close()
 	s.chtIndexer.Close()
 	s.chtIndexer.Close()
 	s.blockchain.Stop()
 	s.blockchain.Stop()
-	s.protocolManager.Stop()
+	s.handler.stop()
 	s.txPool.Stop()
 	s.txPool.Stop()
 	s.engine.Close()
 	s.engine.Close()
-
 	s.eventMux.Stop()
 	s.eventMux.Stop()
-
-	time.Sleep(time.Millisecond * 200)
+	s.serverPool.stop()
 	s.chainDb.Close()
 	s.chainDb.Close()
-	close(s.shutdownChan)
-
+	s.wg.Wait()
+	log.Info("Light ethereum stopped")
 	return nil
 	return nil
 }
 }
 
 
 // SetClient sets the rpc client and binds the registrar contract.
 // SetClient sets the rpc client and binds the registrar contract.
 func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) {
 func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) {
-	// Short circuit if registrar is nil
-	if s.protocolManager.reg == nil {
+	if s.oracle == nil {
 		return
 		return
 	}
 	}
-	s.protocolManager.reg.start(backend)
+	s.oracle.start(backend)
 }
 }

+ 401 - 0
les/client_handler.go

@@ -0,0 +1,401 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package les
+
+import (
+	"math/big"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/eth/downloader"
+	"github.com/ethereum/go-ethereum/light"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/params"
+)
+
+// clientHandler is responsible for receiving and processing all incoming server
+// responses.
+type clientHandler struct {
+	ulc        *ulc
+	checkpoint *params.TrustedCheckpoint
+	fetcher    *lightFetcher
+	downloader *downloader.Downloader
+	backend    *LightEthereum
+
+	closeCh chan struct{}
+	wg      sync.WaitGroup // WaitGroup used to track all connected peers.
+}
+
+func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler {
+	handler := &clientHandler{
+		backend: backend,
+		closeCh: make(chan struct{}),
+	}
+	if ulcServers != nil {
+		ulc, err := newULC(ulcServers, ulcFraction)
+		if err != nil {
+			log.Error("Failed to initialize ultra light client")
+		}
+		handler.ulc = ulc
+		log.Info("Enable ultra light client mode")
+	}
+	var height uint64
+	if checkpoint != nil {
+		height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1
+	}
+	handler.fetcher = newLightFetcher(handler)
+	handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer)
+	handler.backend.peers.notify((*downloaderPeerNotify)(handler))
+	return handler
+}
+
+func (h *clientHandler) stop() {
+	close(h.closeCh)
+	h.downloader.Terminate()
+	h.fetcher.close()
+	h.wg.Wait()
+}
+
+// runPeer is the p2p protocol run function for the given version.
+func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
+	trusted := false
+	if h.ulc != nil {
+		trusted = h.ulc.trusted(p.ID())
+	}
+	peer := newPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version)))
+	peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node())
+	if peer.poolEntry == nil {
+		return p2p.DiscRequested
+	}
+	h.wg.Add(1)
+	defer h.wg.Done()
+	err := h.handle(peer)
+	h.backend.serverPool.disconnect(peer.poolEntry)
+	return err
+}
+
+func (h *clientHandler) handle(p *peer) error {
+	if h.backend.peers.Len() >= h.backend.config.LightPeers && !p.Peer.Info().Network.Trusted {
+		return p2p.DiscTooManyPeers
+	}
+	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
+
+	// Execute the LES handshake
+	var (
+		head   = h.backend.blockchain.CurrentHeader()
+		hash   = head.Hash()
+		number = head.Number.Uint64()
+		td     = h.backend.blockchain.GetTd(hash, number)
+	)
+	if err := p.Handshake(td, hash, number, h.backend.blockchain.Genesis().Hash(), nil); err != nil {
+		p.Log().Debug("Light Ethereum handshake failed", "err", err)
+		return err
+	}
+	// Register the peer locally
+	if err := h.backend.peers.Register(p); err != nil {
+		p.Log().Error("Light Ethereum peer registration failed", "err", err)
+		return err
+	}
+	serverConnectionGauge.Update(int64(h.backend.peers.Len()))
+
+	connectedAt := mclock.Now()
+	defer func() {
+		h.backend.peers.Unregister(p.id)
+		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
+		serverConnectionGauge.Update(int64(h.backend.peers.Len()))
+	}()
+
+	h.fetcher.announce(p, p.headInfo)
+
+	// pool entry can be nil during the unit test.
+	if p.poolEntry != nil {
+		h.backend.serverPool.registered(p.poolEntry)
+	}
+	// Spawn a main loop to handle all incoming messages.
+	for {
+		if err := h.handleMsg(p); err != nil {
+			p.Log().Debug("Light Ethereum message handling failed", "err", err)
+			p.fcServer.DumpLogs()
+			return err
+		}
+	}
+}
+
+// handleMsg is invoked whenever an inbound message is received from a remote
+// peer. The remote connection is torn down upon returning any error.
+func (h *clientHandler) handleMsg(p *peer) error {
+	// Read the next message from the remote peer, and ensure it's fully consumed
+	msg, err := p.rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
+
+	if msg.Size > ProtocolMaxMsgSize {
+		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
+	}
+	defer msg.Discard()
+
+	var deliverMsg *Msg
+
+	// Handle the message depending on its contents
+	switch msg.Code {
+	case AnnounceMsg:
+		p.Log().Trace("Received announce message")
+		var req announceData
+		if err := msg.Decode(&req); err != nil {
+			return errResp(ErrDecode, "%v: %v", msg, err)
+		}
+		if err := req.sanityCheck(); err != nil {
+			return err
+		}
+		update, size := req.Update.decode()
+		if p.rejectUpdate(size) {
+			return errResp(ErrRequestRejected, "")
+		}
+		p.updateFlowControl(update)
+
+		if req.Hash != (common.Hash{}) {
+			if p.announceType == announceTypeNone {
+				return errResp(ErrUnexpectedResponse, "")
+			}
+			if p.announceType == announceTypeSigned {
+				if err := req.checkSignature(p.ID(), update); err != nil {
+					p.Log().Trace("Invalid announcement signature", "err", err)
+					return err
+				}
+				p.Log().Trace("Valid announcement signature")
+			}
+			p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth)
+			h.fetcher.announce(p, &req)
+		}
+	case BlockHeadersMsg:
+		p.Log().Trace("Received block header response message")
+		var resp struct {
+			ReqID, BV uint64
+			Headers   []*types.Header
+		}
+		if err := msg.Decode(&resp); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+		if h.fetcher.requestedID(resp.ReqID) {
+			h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers)
+		} else {
+			if err := h.downloader.DeliverHeaders(p.id, resp.Headers); err != nil {
+				log.Debug("Failed to deliver headers", "err", err)
+			}
+		}
+	case BlockBodiesMsg:
+		p.Log().Trace("Received block bodies response")
+		var resp struct {
+			ReqID, BV uint64
+			Data      []*types.Body
+		}
+		if err := msg.Decode(&resp); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+		deliverMsg = &Msg{
+			MsgType: MsgBlockBodies,
+			ReqID:   resp.ReqID,
+			Obj:     resp.Data,
+		}
+	case CodeMsg:
+		p.Log().Trace("Received code response")
+		var resp struct {
+			ReqID, BV uint64
+			Data      [][]byte
+		}
+		if err := msg.Decode(&resp); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+		deliverMsg = &Msg{
+			MsgType: MsgCode,
+			ReqID:   resp.ReqID,
+			Obj:     resp.Data,
+		}
+	case ReceiptsMsg:
+		p.Log().Trace("Received receipts response")
+		var resp struct {
+			ReqID, BV uint64
+			Receipts  []types.Receipts
+		}
+		if err := msg.Decode(&resp); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+		deliverMsg = &Msg{
+			MsgType: MsgReceipts,
+			ReqID:   resp.ReqID,
+			Obj:     resp.Receipts,
+		}
+	case ProofsV2Msg:
+		p.Log().Trace("Received les/2 proofs response")
+		var resp struct {
+			ReqID, BV uint64
+			Data      light.NodeList
+		}
+		if err := msg.Decode(&resp); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+		deliverMsg = &Msg{
+			MsgType: MsgProofsV2,
+			ReqID:   resp.ReqID,
+			Obj:     resp.Data,
+		}
+	case HelperTrieProofsMsg:
+		p.Log().Trace("Received helper trie proof response")
+		var resp struct {
+			ReqID, BV uint64
+			Data      HelperTrieResps
+		}
+		if err := msg.Decode(&resp); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+		deliverMsg = &Msg{
+			MsgType: MsgHelperTrieProofs,
+			ReqID:   resp.ReqID,
+			Obj:     resp.Data,
+		}
+	case TxStatusMsg:
+		p.Log().Trace("Received tx status response")
+		var resp struct {
+			ReqID, BV uint64
+			Status    []light.TxStatus
+		}
+		if err := msg.Decode(&resp); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
+		deliverMsg = &Msg{
+			MsgType: MsgTxStatus,
+			ReqID:   resp.ReqID,
+			Obj:     resp.Status,
+		}
+	case StopMsg:
+		p.freezeServer(true)
+		h.backend.retriever.frozen(p)
+		p.Log().Debug("Service stopped")
+	case ResumeMsg:
+		var bv uint64
+		if err := msg.Decode(&bv); err != nil {
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		p.fcServer.ResumeFreeze(bv)
+		p.freezeServer(false)
+		p.Log().Debug("Service resumed")
+	default:
+		p.Log().Trace("Received invalid message", "code", msg.Code)
+		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
+	}
+	// Deliver the received response to retriever.
+	if deliverMsg != nil {
+		if err := h.backend.retriever.deliver(p, deliverMsg); err != nil {
+			p.responseErrors++
+			if p.responseErrors > maxResponseErrors {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (h *clientHandler) removePeer(id string) {
+	h.backend.peers.Unregister(id)
+}
+
+type peerConnection struct {
+	handler *clientHandler
+	peer    *peer
+}
+
+func (pc *peerConnection) Head() (common.Hash, *big.Int) {
+	return pc.peer.HeadAndTd()
+}
+
+func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
+	rq := &distReq{
+		getCost: func(dp distPeer) uint64 {
+			peer := dp.(*peer)
+			return peer.GetRequestCost(GetBlockHeadersMsg, amount)
+		},
+		canSend: func(dp distPeer) bool {
+			return dp.(*peer) == pc.peer
+		},
+		request: func(dp distPeer) func() {
+			reqID := genReqID()
+			peer := dp.(*peer)
+			cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
+			peer.fcServer.QueuedRequest(reqID, cost)
+			return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) }
+		},
+	}
+	_, ok := <-pc.handler.backend.reqDist.queue(rq)
+	if !ok {
+		return light.ErrNoPeers
+	}
+	return nil
+}
+
+func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
+	rq := &distReq{
+		getCost: func(dp distPeer) uint64 {
+			peer := dp.(*peer)
+			return peer.GetRequestCost(GetBlockHeadersMsg, amount)
+		},
+		canSend: func(dp distPeer) bool {
+			return dp.(*peer) == pc.peer
+		},
+		request: func(dp distPeer) func() {
+			reqID := genReqID()
+			peer := dp.(*peer)
+			cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
+			peer.fcServer.QueuedRequest(reqID, cost)
+			return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) }
+		},
+	}
+	_, ok := <-pc.handler.backend.reqDist.queue(rq)
+	if !ok {
+		return light.ErrNoPeers
+	}
+	return nil
+}
+
+// downloaderPeerNotify implements peerSetNotify
+type downloaderPeerNotify clientHandler
+
+func (d *downloaderPeerNotify) registerPeer(p *peer) {
+	h := (*clientHandler)(d)
+	pc := &peerConnection{
+		handler: h,
+		peer:    p,
+	}
+	h.downloader.RegisterLightPeer(p.id, ethVersion, pc)
+}
+
+func (d *downloaderPeerNotify) unregisterPeer(p *peer) {
+	h := (*clientHandler)(d)
+	h.downloader.UnregisterPeer(p.id)
+}

+ 47 - 22
les/commons.go

@@ -17,25 +17,56 @@
 package les
 package les
 
 
 import (
 import (
+	"fmt"
 	"math/big"
 	"math/big"
+	"sync"
 
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/eth"
 	"github.com/ethereum/go-ethereum/eth"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/discv5"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/ethereum/go-ethereum/params"
 )
 )
 
 
+func errResp(code errCode, format string, v ...interface{}) error {
+	return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...))
+}
+
+func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic {
+	var name string
+	switch protocolVersion {
+	case lpv2:
+		name = "LES2"
+	default:
+		panic(nil)
+	}
+	return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8]))
+}
+
+type chainReader interface {
+	CurrentHeader() *types.Header
+}
+
 // lesCommons contains fields needed by both server and client.
 // lesCommons contains fields needed by both server and client.
 type lesCommons struct {
 type lesCommons struct {
+	genesis                      common.Hash
 	config                       *eth.Config
 	config                       *eth.Config
+	chainConfig                  *params.ChainConfig
 	iConfig                      *light.IndexerConfig
 	iConfig                      *light.IndexerConfig
 	chainDb                      ethdb.Database
 	chainDb                      ethdb.Database
-	protocolManager              *ProtocolManager
+	peers                        *peerSet
+	chainReader                  chainReader
 	chtIndexer, bloomTrieIndexer *core.ChainIndexer
 	chtIndexer, bloomTrieIndexer *core.ChainIndexer
+	oracle                       *checkpointOracle
+
+	closeCh chan struct{}
+	wg      sync.WaitGroup
 }
 }
 
 
 // NodeInfo represents a short summary of the Ethereum sub-protocol metadata
 // NodeInfo represents a short summary of the Ethereum sub-protocol metadata
@@ -50,7 +81,7 @@ type NodeInfo struct {
 }
 }
 
 
 // makeProtocols creates protocol descriptors for the given LES versions.
 // makeProtocols creates protocol descriptors for the given LES versions.
-func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol {
+func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol {
 	protos := make([]p2p.Protocol, len(versions))
 	protos := make([]p2p.Protocol, len(versions))
 	for i, version := range versions {
 	for i, version := range versions {
 		version := version
 		version := version
@@ -59,15 +90,10 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol {
 			Version:  version,
 			Version:  version,
 			Length:   ProtocolLengths[version],
 			Length:   ProtocolLengths[version],
 			NodeInfo: c.nodeInfo,
 			NodeInfo: c.nodeInfo,
-			Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
-				return c.protocolManager.runPeer(version, p, rw)
-			},
-			PeerInfo: func(id enode.ID) interface{} {
-				if p := c.protocolManager.peers.Peer(peerIdToString(id)); p != nil {
-					return p.Info()
-				}
-				return nil
+			Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
+				return runPeer(version, peer, rw)
 			},
 			},
+			PeerInfo: peerInfo,
 		}
 		}
 	}
 	}
 	return protos
 	return protos
@@ -75,22 +101,21 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol {
 
 
 // nodeInfo retrieves some protocol metadata about the running host node.
 // nodeInfo retrieves some protocol metadata about the running host node.
 func (c *lesCommons) nodeInfo() interface{} {
 func (c *lesCommons) nodeInfo() interface{} {
-	chain := c.protocolManager.blockchain
-	head := chain.CurrentHeader()
+	head := c.chainReader.CurrentHeader()
 	hash := head.Hash()
 	hash := head.Hash()
 	return &NodeInfo{
 	return &NodeInfo{
 		Network:    c.config.NetworkId,
 		Network:    c.config.NetworkId,
-		Difficulty: chain.GetTd(hash, head.Number.Uint64()),
-		Genesis:    chain.Genesis().Hash(),
-		Config:     chain.Config(),
-		Head:       chain.CurrentHeader().Hash(),
+		Difficulty: rawdb.ReadTd(c.chainDb, hash, head.Number.Uint64()),
+		Genesis:    c.genesis,
+		Config:     c.chainConfig,
+		Head:       hash,
 		CHT:        c.latestLocalCheckpoint(),
 		CHT:        c.latestLocalCheckpoint(),
 	}
 	}
 }
 }
 
 
-// latestLocalCheckpoint finds the common stored section index and returns a set of
-// post-processed trie roots (CHT and BloomTrie) associated with
-// the appropriate section index and head hash as a local checkpoint package.
+// latestLocalCheckpoint finds the common stored section index and returns a set
+// of post-processed trie roots (CHT and BloomTrie) associated with the appropriate
+// section index and head hash as a local checkpoint package.
 func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint {
 func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint {
 	sections, _, _ := c.chtIndexer.Sections()
 	sections, _, _ := c.chtIndexer.Sections()
 	sections2, _, _ := c.bloomTrieIndexer.Sections()
 	sections2, _, _ := c.bloomTrieIndexer.Sections()
@@ -102,15 +127,15 @@ func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint {
 		// No checkpoint information can be provided.
 		// No checkpoint information can be provided.
 		return params.TrustedCheckpoint{}
 		return params.TrustedCheckpoint{}
 	}
 	}
-	return c.getLocalCheckpoint(sections - 1)
+	return c.localCheckpoint(sections - 1)
 }
 }
 
 
-// getLocalCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie)
+// localCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie)
 // associated with the appropriate head hash by specific section index.
 // associated with the appropriate head hash by specific section index.
 //
 //
 // The returned checkpoint is only the checkpoint generated by the local indexers,
 // The returned checkpoint is only the checkpoint generated by the local indexers,
 // not the stable checkpoint registered in the registrar contract.
 // not the stable checkpoint registered in the registrar contract.
-func (c *lesCommons) getLocalCheckpoint(index uint64) params.TrustedCheckpoint {
+func (c *lesCommons) localCheckpoint(index uint64) params.TrustedCheckpoint {
 	sectionHead := c.chtIndexer.SectionHead(index)
 	sectionHead := c.chtIndexer.SectionHead(index)
 	return params.TrustedCheckpoint{
 	return params.TrustedCheckpoint{
 		SectionIndex: index,
 		SectionIndex: index,

+ 8 - 3
les/costtracker.go

@@ -81,7 +81,8 @@ var (
 )
 )
 
 
 const (
 const (
-	maxCostFactor    = 2 // ratio of maximum and average cost estimates
+	maxCostFactor    = 2    // ratio of maximum and average cost estimates
+	bufLimitRatio    = 6000 // fixed bufLimit/MRR ratio
 	gfUsageThreshold = 0.5
 	gfUsageThreshold = 0.5
 	gfUsageTC        = time.Second
 	gfUsageTC        = time.Second
 	gfRaiseTC        = time.Second * 200
 	gfRaiseTC        = time.Second * 200
@@ -127,6 +128,10 @@ type costTracker struct {
 	totalRechargeCh chan uint64
 	totalRechargeCh chan uint64
 
 
 	stats map[uint64][]uint64 // Used for testing purpose.
 	stats map[uint64][]uint64 // Used for testing purpose.
+
+	// TestHooks
+	testing      bool            // Disable real cost evaluation for testing purpose.
+	testCostList RequestCostList // Customized cost table for testing purpose.
 }
 }
 
 
 // newCostTracker creates a cost tracker and loads the cost factor statistics from the database.
 // newCostTracker creates a cost tracker and loads the cost factor statistics from the database.
@@ -265,8 +270,9 @@ func (ct *costTracker) gfLoop() {
 			select {
 			select {
 			case r := <-ct.reqInfoCh:
 			case r := <-ct.reqInfoCh:
 				requestServedMeter.Mark(int64(r.servingTime))
 				requestServedMeter.Mark(int64(r.servingTime))
-				requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor))
 				requestServedTimer.Update(time.Duration(r.servingTime))
 				requestServedTimer.Update(time.Duration(r.servingTime))
+				requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor))
+				requestEstimatedTimer.Update(time.Duration(r.avgTimeCost / factor))
 				relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime))
 				relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime))
 
 
 				now := mclock.Now()
 				now := mclock.Now()
@@ -323,7 +329,6 @@ func (ct *costTracker) gfLoop() {
 				}
 				}
 				recentServedGauge.Update(int64(recentTime))
 				recentServedGauge.Update(int64(recentTime))
 				recentEstimatedGauge.Update(int64(recentAvg))
 				recentEstimatedGauge.Update(int64(recentAvg))
-				totalRechargeGauge.Update(int64(totalRecharge))
 
 
 			case <-saveTicker.C:
 			case <-saveTicker.C:
 				saveCostFactor()
 				saveCostFactor()

+ 26 - 11
les/distributor.go

@@ -28,14 +28,17 @@ import (
 // suitable peers, obeying flow control rules and prioritizing them in creation
 // suitable peers, obeying flow control rules and prioritizing them in creation
 // order (even when a resend is necessary).
 // order (even when a resend is necessary).
 type requestDistributor struct {
 type requestDistributor struct {
-	clock            mclock.Clock
-	reqQueue         *list.List
-	lastReqOrder     uint64
-	peers            map[distPeer]struct{}
-	peerLock         sync.RWMutex
-	stopChn, loopChn chan struct{}
-	loopNextSent     bool
-	lock             sync.Mutex
+	clock        mclock.Clock
+	reqQueue     *list.List
+	lastReqOrder uint64
+	peers        map[distPeer]struct{}
+	peerLock     sync.RWMutex
+	loopChn      chan struct{}
+	loopNextSent bool
+	lock         sync.Mutex
+
+	closeCh chan struct{}
+	wg      sync.WaitGroup
 }
 }
 
 
 // distPeer is an LES server peer interface for the request distributor.
 // distPeer is an LES server peer interface for the request distributor.
@@ -66,20 +69,22 @@ type distReq struct {
 	sentChn      chan distPeer
 	sentChn      chan distPeer
 	element      *list.Element
 	element      *list.Element
 	waitForPeers mclock.AbsTime
 	waitForPeers mclock.AbsTime
+	enterQueue   mclock.AbsTime
 }
 }
 
 
 // newRequestDistributor creates a new request distributor
 // newRequestDistributor creates a new request distributor
-func newRequestDistributor(peers *peerSet, stopChn chan struct{}, clock mclock.Clock) *requestDistributor {
+func newRequestDistributor(peers *peerSet, clock mclock.Clock) *requestDistributor {
 	d := &requestDistributor{
 	d := &requestDistributor{
 		clock:    clock,
 		clock:    clock,
 		reqQueue: list.New(),
 		reqQueue: list.New(),
 		loopChn:  make(chan struct{}, 2),
 		loopChn:  make(chan struct{}, 2),
-		stopChn:  stopChn,
+		closeCh:  make(chan struct{}),
 		peers:    make(map[distPeer]struct{}),
 		peers:    make(map[distPeer]struct{}),
 	}
 	}
 	if peers != nil {
 	if peers != nil {
 		peers.notify(d)
 		peers.notify(d)
 	}
 	}
+	d.wg.Add(1)
 	go d.loop()
 	go d.loop()
 	return d
 	return d
 }
 }
@@ -115,9 +120,10 @@ const waitForPeers = time.Second * 3
 
 
 // main event loop
 // main event loop
 func (d *requestDistributor) loop() {
 func (d *requestDistributor) loop() {
+	defer d.wg.Done()
 	for {
 	for {
 		select {
 		select {
-		case <-d.stopChn:
+		case <-d.closeCh:
 			d.lock.Lock()
 			d.lock.Lock()
 			elem := d.reqQueue.Front()
 			elem := d.reqQueue.Front()
 			for elem != nil {
 			for elem != nil {
@@ -140,6 +146,7 @@ func (d *requestDistributor) loop() {
 					send := req.request(peer)
 					send := req.request(peer)
 					if send != nil {
 					if send != nil {
 						peer.queueSend(send)
 						peer.queueSend(send)
+						requestSendDelay.Update(time.Duration(d.clock.Now() - req.enterQueue))
 					}
 					}
 					chn <- peer
 					chn <- peer
 					close(chn)
 					close(chn)
@@ -249,6 +256,9 @@ func (d *requestDistributor) queue(r *distReq) chan distPeer {
 		r.reqOrder = d.lastReqOrder
 		r.reqOrder = d.lastReqOrder
 		r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers)
 		r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers)
 	}
 	}
+	// Assign the timestamp when the request is queued no matter it's
+	// a new one or re-queued one.
+	r.enterQueue = d.clock.Now()
 
 
 	back := d.reqQueue.Back()
 	back := d.reqQueue.Back()
 	if back == nil || r.reqOrder > back.Value.(*distReq).reqOrder {
 	if back == nil || r.reqOrder > back.Value.(*distReq).reqOrder {
@@ -294,3 +304,8 @@ func (d *requestDistributor) remove(r *distReq) {
 		r.element = nil
 		r.element = nil
 	}
 	}
 }
 }
+
+func (d *requestDistributor) close() {
+	close(d.closeCh)
+	d.wg.Wait()
+}

+ 1 - 1
les/distributor_test.go

@@ -121,7 +121,7 @@ func testRequestDistributor(t *testing.T, resend bool) {
 	stop := make(chan struct{})
 	stop := make(chan struct{})
 	defer close(stop)
 	defer close(stop)
 
 
-	dist := newRequestDistributor(nil, stop, &mclock.System{})
+	dist := newRequestDistributor(nil, &mclock.System{})
 	var peers [testDistPeerCount]*testDistPeer
 	var peers [testDistPeerCount]*testDistPeer
 	for i := range peers {
 	for i := range peers {
 		peers[i] = &testDistPeer{}
 		peers[i] = &testDistPeer{}

+ 37 - 38
les/fetcher.go

@@ -40,9 +40,8 @@ const (
 // ODR system to ensure that we only request data related to a certain block from peers who have already processed
 // ODR system to ensure that we only request data related to a certain block from peers who have already processed
 // and announced that block.
 // and announced that block.
 type lightFetcher struct {
 type lightFetcher struct {
-	pm    *ProtocolManager
-	odr   *LesOdr
-	chain lightChain
+	handler *clientHandler
+	chain   *light.LightChain
 
 
 	lock            sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests
 	lock            sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests
 	maxConfirmedTd  *big.Int
 	maxConfirmedTd  *big.Int
@@ -58,13 +57,9 @@ type lightFetcher struct {
 	requestTriggered  bool
 	requestTriggered  bool
 	requestTrigger    chan struct{}
 	requestTrigger    chan struct{}
 	lastTrustedHeader *types.Header
 	lastTrustedHeader *types.Header
-}
 
 
-// lightChain extends the BlockChain interface by locking.
-type lightChain interface {
-	BlockChain
-	LockChain()
-	UnlockChain()
+	closeCh chan struct{}
+	wg      sync.WaitGroup
 }
 }
 
 
 // fetcherPeerInfo holds fetcher-specific information about each active peer
 // fetcherPeerInfo holds fetcher-specific information about each active peer
@@ -114,32 +109,37 @@ type fetchResponse struct {
 }
 }
 
 
 // newLightFetcher creates a new light fetcher
 // newLightFetcher creates a new light fetcher
-func newLightFetcher(pm *ProtocolManager) *lightFetcher {
+func newLightFetcher(h *clientHandler) *lightFetcher {
 	f := &lightFetcher{
 	f := &lightFetcher{
-		pm:             pm,
-		chain:          pm.blockchain.(*light.LightChain),
-		odr:            pm.odr,
+		handler:        h,
+		chain:          h.backend.blockchain,
 		peers:          make(map[*peer]*fetcherPeerInfo),
 		peers:          make(map[*peer]*fetcherPeerInfo),
 		deliverChn:     make(chan fetchResponse, 100),
 		deliverChn:     make(chan fetchResponse, 100),
 		requested:      make(map[uint64]fetchRequest),
 		requested:      make(map[uint64]fetchRequest),
 		timeoutChn:     make(chan uint64),
 		timeoutChn:     make(chan uint64),
 		requestTrigger: make(chan struct{}, 1),
 		requestTrigger: make(chan struct{}, 1),
 		syncDone:       make(chan *peer),
 		syncDone:       make(chan *peer),
+		closeCh:        make(chan struct{}),
 		maxConfirmedTd: big.NewInt(0),
 		maxConfirmedTd: big.NewInt(0),
 	}
 	}
-	pm.peers.notify(f)
+	h.backend.peers.notify(f)
 
 
-	f.pm.wg.Add(1)
+	f.wg.Add(1)
 	go f.syncLoop()
 	go f.syncLoop()
 	return f
 	return f
 }
 }
 
 
+func (f *lightFetcher) close() {
+	close(f.closeCh)
+	f.wg.Wait()
+}
+
 // syncLoop is the main event loop of the light fetcher
 // syncLoop is the main event loop of the light fetcher
 func (f *lightFetcher) syncLoop() {
 func (f *lightFetcher) syncLoop() {
-	defer f.pm.wg.Done()
+	defer f.wg.Done()
 	for {
 	for {
 		select {
 		select {
-		case <-f.pm.quitSync:
+		case <-f.closeCh:
 			return
 			return
 		// request loop keeps running until no further requests are necessary or possible
 		// request loop keeps running until no further requests are necessary or possible
 		case <-f.requestTrigger:
 		case <-f.requestTrigger:
@@ -156,7 +156,7 @@ func (f *lightFetcher) syncLoop() {
 			f.lock.Unlock()
 			f.lock.Unlock()
 
 
 			if rq != nil {
 			if rq != nil {
-				if _, ok := <-f.pm.reqDist.queue(rq); ok {
+				if _, ok := <-f.handler.backend.reqDist.queue(rq); ok {
 					if syncing {
 					if syncing {
 						f.lock.Lock()
 						f.lock.Lock()
 						f.syncing = true
 						f.syncing = true
@@ -187,9 +187,9 @@ func (f *lightFetcher) syncLoop() {
 			}
 			}
 			f.reqMu.Unlock()
 			f.reqMu.Unlock()
 			if ok {
 			if ok {
-				f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true)
+				f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true)
 				req.peer.Log().Debug("Fetching data timed out hard")
 				req.peer.Log().Debug("Fetching data timed out hard")
-				go f.pm.removePeer(req.peer.id)
+				go f.handler.removePeer(req.peer.id)
 			}
 			}
 		case resp := <-f.deliverChn:
 		case resp := <-f.deliverChn:
 			f.reqMu.Lock()
 			f.reqMu.Lock()
@@ -202,12 +202,12 @@ func (f *lightFetcher) syncLoop() {
 			}
 			}
 			f.reqMu.Unlock()
 			f.reqMu.Unlock()
 			if ok {
 			if ok {
-				f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout)
+				f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout)
 			}
 			}
 			f.lock.Lock()
 			f.lock.Lock()
 			if !ok || !(f.syncing || f.processResponse(req, resp)) {
 			if !ok || !(f.syncing || f.processResponse(req, resp)) {
 				resp.peer.Log().Debug("Failed processing response")
 				resp.peer.Log().Debug("Failed processing response")
-				go f.pm.removePeer(resp.peer.id)
+				go f.handler.removePeer(resp.peer.id)
 			}
 			}
 			f.lock.Unlock()
 			f.lock.Unlock()
 		case p := <-f.syncDone:
 		case p := <-f.syncDone:
@@ -264,7 +264,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) {
 	if fp.lastAnnounced != nil && head.Td.Cmp(fp.lastAnnounced.td) <= 0 {
 	if fp.lastAnnounced != nil && head.Td.Cmp(fp.lastAnnounced.td) <= 0 {
 		// announced tds should be strictly monotonic
 		// announced tds should be strictly monotonic
 		p.Log().Debug("Received non-monotonic td", "current", head.Td, "previous", fp.lastAnnounced.td)
 		p.Log().Debug("Received non-monotonic td", "current", head.Td, "previous", fp.lastAnnounced.td)
-		go f.pm.removePeer(p.id)
+		go f.handler.removePeer(p.id)
 		return
 		return
 	}
 	}
 
 
@@ -297,7 +297,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) {
 			// if one of root's children is canonical, keep it, delete other branches and root itself
 			// if one of root's children is canonical, keep it, delete other branches and root itself
 			var newRoot *fetcherTreeNode
 			var newRoot *fetcherTreeNode
 			for i, nn := range fp.root.children {
 			for i, nn := range fp.root.children {
-				if rawdb.ReadCanonicalHash(f.pm.chainDb, nn.number) == nn.hash {
+				if rawdb.ReadCanonicalHash(f.handler.backend.chainDb, nn.number) == nn.hash {
 					fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...)
 					fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...)
 					nn.parent = nil
 					nn.parent = nil
 					newRoot = nn
 					newRoot = nn
@@ -390,7 +390,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64, ha
 	//
 	//
 	// when syncing, just check if it is part of the known chain, there is nothing better we
 	// when syncing, just check if it is part of the known chain, there is nothing better we
 	// can do since we do not know the most recent block hash yet
 	// can do since we do not know the most recent block hash yet
-	return rawdb.ReadCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.pm.chainDb, number) == hash
+	return rawdb.ReadCanonicalHash(f.handler.backend.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.handler.backend.chainDb, number) == hash
 }
 }
 
 
 // requestAmount calculates the amount of headers to be downloaded starting
 // requestAmount calculates the amount of headers to be downloaded starting
@@ -453,8 +453,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6
 			if f.checkKnownNode(p, n) || n.requested {
 			if f.checkKnownNode(p, n) || n.requested {
 				continue
 				continue
 			}
 			}
-
-			//if ulc mode is disabled, isTrustedHash returns true
+			// if ulc mode is disabled, isTrustedHash returns true
 			amount := f.requestAmount(p, n)
 			amount := f.requestAmount(p, n)
 			if (bestTd == nil || n.td.Cmp(bestTd) > 0 || amount < bestAmount) && (f.isTrustedHash(hash) || f.maxConfirmedTd.Int64() == 0) {
 			if (bestTd == nil || n.td.Cmp(bestTd) > 0 || amount < bestAmount) && (f.isTrustedHash(hash) || f.maxConfirmedTd.Int64() == 0) {
 				bestHash = hash
 				bestHash = hash
@@ -470,7 +469,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6
 // isTrustedHash checks if the block can be trusted by the minimum trusted fraction.
 // isTrustedHash checks if the block can be trusted by the minimum trusted fraction.
 func (f *lightFetcher) isTrustedHash(hash common.Hash) bool {
 func (f *lightFetcher) isTrustedHash(hash common.Hash) bool {
 	// If ultra light cliet mode is disabled, trust all hashes
 	// If ultra light cliet mode is disabled, trust all hashes
-	if f.pm.ulc == nil {
+	if f.handler.ulc == nil {
 		return true
 		return true
 	}
 	}
 	// Ultra light enabled, only trust after enough confirmations
 	// Ultra light enabled, only trust after enough confirmations
@@ -480,7 +479,7 @@ func (f *lightFetcher) isTrustedHash(hash common.Hash) bool {
 			agreed++
 			agreed++
 		}
 		}
 	}
 	}
-	return 100*agreed/len(f.pm.ulc.keys) >= f.pm.ulc.fraction
+	return 100*agreed/len(f.handler.ulc.keys) >= f.handler.ulc.fraction
 }
 }
 
 
 func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq {
 func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq {
@@ -500,14 +499,14 @@ func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq {
 			return fp != nil && fp.nodeByHash[bestHash] != nil
 			return fp != nil && fp.nodeByHash[bestHash] != nil
 		},
 		},
 		request: func(dp distPeer) func() {
 		request: func(dp distPeer) func() {
-			if f.pm.ulc != nil {
+			if f.handler.ulc != nil {
 				// Keep last trusted header before sync
 				// Keep last trusted header before sync
 				f.setLastTrustedHeader(f.chain.CurrentHeader())
 				f.setLastTrustedHeader(f.chain.CurrentHeader())
 			}
 			}
 			go func() {
 			go func() {
 				p := dp.(*peer)
 				p := dp.(*peer)
 				p.Log().Debug("Synchronisation started")
 				p.Log().Debug("Synchronisation started")
-				f.pm.synchronise(p)
+				f.handler.synchronise(p)
 				f.syncDone <- p
 				f.syncDone <- p
 			}()
 			}()
 			return nil
 			return nil
@@ -607,7 +606,7 @@ func (f *lightFetcher) newHeaders(headers []*types.Header, tds []*big.Int) {
 	for p, fp := range f.peers {
 	for p, fp := range f.peers {
 		if !f.checkAnnouncedHeaders(fp, headers, tds) {
 		if !f.checkAnnouncedHeaders(fp, headers, tds) {
 			p.Log().Debug("Inconsistent announcement")
 			p.Log().Debug("Inconsistent announcement")
-			go f.pm.removePeer(p.id)
+			go f.handler.removePeer(p.id)
 		}
 		}
 		if fp.confirmedTd != nil && (maxTd == nil || maxTd.Cmp(fp.confirmedTd) > 0) {
 		if fp.confirmedTd != nil && (maxTd == nil || maxTd.Cmp(fp.confirmedTd) > 0) {
 			maxTd = fp.confirmedTd
 			maxTd = fp.confirmedTd
@@ -705,7 +704,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) {
 		node = fp.lastAnnounced
 		node = fp.lastAnnounced
 		td   *big.Int
 		td   *big.Int
 	)
 	)
-	if f.pm.ulc != nil {
+	if f.handler.ulc != nil {
 		// Roll back untrusted blocks
 		// Roll back untrusted blocks
 		h, unapproved := f.lastTrustedTreeNode(p)
 		h, unapproved := f.lastTrustedTreeNode(p)
 		f.chain.Rollback(unapproved)
 		f.chain.Rollback(unapproved)
@@ -721,7 +720,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) {
 	// Now node is the latest downloaded/approved header after syncing
 	// Now node is the latest downloaded/approved header after syncing
 	if node == nil {
 	if node == nil {
 		p.Log().Debug("Synchronisation failed")
 		p.Log().Debug("Synchronisation failed")
-		go f.pm.removePeer(p.id)
+		go f.handler.removePeer(p.id)
 		return
 		return
 	}
 	}
 	header := f.chain.GetHeader(node.hash, node.number)
 	header := f.chain.GetHeader(node.hash, node.number)
@@ -741,7 +740,7 @@ func (f *lightFetcher) lastTrustedTreeNode(p *peer) (*types.Header, []common.Has
 	if canonical.Number.Uint64() > f.lastTrustedHeader.Number.Uint64() {
 	if canonical.Number.Uint64() > f.lastTrustedHeader.Number.Uint64() {
 		canonical = f.chain.GetHeaderByNumber(f.lastTrustedHeader.Number.Uint64())
 		canonical = f.chain.GetHeaderByNumber(f.lastTrustedHeader.Number.Uint64())
 	}
 	}
-	commonAncestor := rawdb.FindCommonAncestor(f.pm.chainDb, canonical, f.lastTrustedHeader)
+	commonAncestor := rawdb.FindCommonAncestor(f.handler.backend.chainDb, canonical, f.lastTrustedHeader)
 	if commonAncestor == nil {
 	if commonAncestor == nil {
 		log.Error("Common ancestor of last trusted header and canonical header is nil", "canonical hash", canonical.Hash(), "trusted hash", f.lastTrustedHeader.Hash())
 		log.Error("Common ancestor of last trusted header and canonical header is nil", "canonical hash", canonical.Hash(), "trusted hash", f.lastTrustedHeader.Hash())
 		return current, unapprovedHashes
 		return current, unapprovedHashes
@@ -787,7 +786,7 @@ func (f *lightFetcher) checkKnownNode(p *peer, n *fetcherTreeNode) bool {
 	}
 	}
 	if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) {
 	if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) {
 		p.Log().Debug("Inconsistent announcement")
 		p.Log().Debug("Inconsistent announcement")
-		go f.pm.removePeer(p.id)
+		go f.handler.removePeer(p.id)
 	}
 	}
 	if fp.confirmedTd != nil {
 	if fp.confirmedTd != nil {
 		f.updateMaxConfirmedTd(fp.confirmedTd)
 		f.updateMaxConfirmedTd(fp.confirmedTd)
@@ -880,12 +879,12 @@ func (f *lightFetcher) checkUpdateStats(p *peer, newEntry *updateStatsEntry) {
 		fp.firstUpdateStats = newEntry
 		fp.firstUpdateStats = newEntry
 	}
 	}
 	for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) {
 	for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) {
-		f.pm.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout)
+		f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout)
 		fp.firstUpdateStats = fp.firstUpdateStats.next
 		fp.firstUpdateStats = fp.firstUpdateStats.next
 	}
 	}
 	if fp.confirmedTd != nil {
 	if fp.confirmedTd != nil {
 		for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 {
 		for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 {
-			f.pm.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time))
+			f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time))
 			fp.firstUpdateStats = fp.firstUpdateStats.next
 			fp.firstUpdateStats = fp.firstUpdateStats.next
 		}
 		}
 	}
 	}

+ 0 - 168
les/fetcher_test.go

@@ -1,168 +0,0 @@
-// Copyright 2019 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package les
-
-import (
-	"math/big"
-	"testing"
-
-	"net"
-
-	"github.com/ethereum/go-ethereum/common"
-	"github.com/ethereum/go-ethereum/core/types"
-	"github.com/ethereum/go-ethereum/crypto"
-	"github.com/ethereum/go-ethereum/p2p"
-	"github.com/ethereum/go-ethereum/p2p/enode"
-)
-
-func TestFetcherULCPeerSelector(t *testing.T) {
-	id1 := newNodeID(t).ID()
-	id2 := newNodeID(t).ID()
-	id3 := newNodeID(t).ID()
-	id4 := newNodeID(t).ID()
-
-	ftn1 := &fetcherTreeNode{
-		hash: common.HexToHash("1"),
-		td:   big.NewInt(1),
-	}
-	ftn2 := &fetcherTreeNode{
-		hash:   common.HexToHash("2"),
-		td:     big.NewInt(2),
-		parent: ftn1,
-	}
-	ftn3 := &fetcherTreeNode{
-		hash:   common.HexToHash("3"),
-		td:     big.NewInt(3),
-		parent: ftn2,
-	}
-	lf := lightFetcher{
-		pm: &ProtocolManager{
-			ulc: &ulc{
-				keys: map[string]bool{
-					id1.String(): true,
-					id2.String(): true,
-					id3.String(): true,
-					id4.String(): true,
-				},
-				fraction: 70,
-			},
-		},
-		maxConfirmedTd: ftn1.td,
-
-		peers: map[*peer]*fetcherPeerInfo{
-			{
-				id:      "peer1",
-				Peer:    p2p.NewPeer(id1, "peer1", []p2p.Cap{}),
-				trusted: true,
-			}: {
-				nodeByHash: map[common.Hash]*fetcherTreeNode{
-					ftn1.hash: ftn1,
-					ftn2.hash: ftn2,
-				},
-			},
-			{
-				Peer:    p2p.NewPeer(id2, "peer2", []p2p.Cap{}),
-				id:      "peer2",
-				trusted: true,
-			}: {
-				nodeByHash: map[common.Hash]*fetcherTreeNode{
-					ftn1.hash: ftn1,
-					ftn2.hash: ftn2,
-				},
-			},
-			{
-				id:      "peer3",
-				Peer:    p2p.NewPeer(id3, "peer3", []p2p.Cap{}),
-				trusted: true,
-			}: {
-				nodeByHash: map[common.Hash]*fetcherTreeNode{
-					ftn1.hash: ftn1,
-					ftn2.hash: ftn2,
-					ftn3.hash: ftn3,
-				},
-			},
-			{
-				id:      "peer4",
-				Peer:    p2p.NewPeer(id4, "peer4", []p2p.Cap{}),
-				trusted: true,
-			}: {
-				nodeByHash: map[common.Hash]*fetcherTreeNode{
-					ftn1.hash: ftn1,
-				},
-			},
-		},
-		chain: &lightChainStub{
-			tds: map[common.Hash]*big.Int{},
-			headers: map[common.Hash]*types.Header{
-				ftn1.hash: {},
-				ftn2.hash: {},
-				ftn3.hash: {},
-			},
-		},
-	}
-	bestHash, bestAmount, bestTD, sync := lf.findBestRequest()
-
-	if bestTD == nil {
-		t.Fatal("Empty result")
-	}
-
-	if bestTD.Cmp(ftn2.td) != 0 {
-		t.Fatal("bad td", bestTD)
-	}
-	if bestHash != ftn2.hash {
-		t.Fatal("bad hash", bestTD)
-	}
-
-	_, _ = bestAmount, sync
-}
-
-type lightChainStub struct {
-	BlockChain
-	tds                         map[common.Hash]*big.Int
-	headers                     map[common.Hash]*types.Header
-	insertHeaderChainAssertFunc func(chain []*types.Header, checkFreq int) (int, error)
-}
-
-func (l *lightChainStub) GetHeader(hash common.Hash, number uint64) *types.Header {
-	if h, ok := l.headers[hash]; ok {
-		return h
-	}
-
-	return nil
-}
-
-func (l *lightChainStub) LockChain()   {}
-func (l *lightChainStub) UnlockChain() {}
-
-func (l *lightChainStub) GetTd(hash common.Hash, number uint64) *big.Int {
-	if td, ok := l.tds[hash]; ok {
-		return td
-	}
-	return nil
-}
-
-func (l *lightChainStub) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) {
-	return l.insertHeaderChainAssertFunc(chain, checkFreq)
-}
-
-func newNodeID(t *testing.T) *enode.Node {
-	key, err := crypto.GenerateKey()
-	if err != nil {
-		t.Fatal("generate key err:", err)
-	}
-	return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
-}

+ 0 - 1293
les/handler.go

@@ -1,1293 +0,0 @@
-// Copyright 2016 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package les
-
-import (
-	"encoding/binary"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"math/big"
-	"sync"
-	"sync/atomic"
-	"time"
-
-	"github.com/ethereum/go-ethereum/common"
-	"github.com/ethereum/go-ethereum/common/mclock"
-	"github.com/ethereum/go-ethereum/core"
-	"github.com/ethereum/go-ethereum/core/rawdb"
-	"github.com/ethereum/go-ethereum/core/state"
-	"github.com/ethereum/go-ethereum/core/types"
-	"github.com/ethereum/go-ethereum/eth/downloader"
-	"github.com/ethereum/go-ethereum/ethdb"
-	"github.com/ethereum/go-ethereum/event"
-	"github.com/ethereum/go-ethereum/light"
-	"github.com/ethereum/go-ethereum/log"
-	"github.com/ethereum/go-ethereum/p2p"
-	"github.com/ethereum/go-ethereum/p2p/discv5"
-	"github.com/ethereum/go-ethereum/params"
-	"github.com/ethereum/go-ethereum/rlp"
-	"github.com/ethereum/go-ethereum/trie"
-)
-
-var errTooManyInvalidRequest = errors.New("too many invalid requests made")
-
-const (
-	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
-	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
-
-	ethVersion = 63 // equivalent eth version for the downloader
-
-	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
-	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
-	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
-	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
-	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
-	MaxHelperTrieProofsFetch = 64  // Amount of merkle proofs to be fetched per retrieval request
-	MaxTxSend                = 64  // Amount of transactions to be send per request
-	MaxTxStatus              = 256 // Amount of transactions to queried per request
-
-	disableClientRemovePeer = false
-)
-
-func errResp(code errCode, format string, v ...interface{}) error {
-	return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...))
-}
-
-type BlockChain interface {
-	Config() *params.ChainConfig
-	HasHeader(hash common.Hash, number uint64) bool
-	GetHeader(hash common.Hash, number uint64) *types.Header
-	GetHeaderByHash(hash common.Hash) *types.Header
-	CurrentHeader() *types.Header
-	GetTd(hash common.Hash, number uint64) *big.Int
-	StateCache() state.Database
-	InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error)
-	Rollback(chain []common.Hash)
-	GetHeaderByNumber(number uint64) *types.Header
-	GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64)
-	Genesis() *types.Block
-	SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription
-}
-
-type txPool interface {
-	AddRemotes(txs []*types.Transaction) []error
-	AddRemotesSync(txs []*types.Transaction) []error
-	Status(hashes []common.Hash) []core.TxStatus
-}
-
-type ProtocolManager struct {
-	// Configs
-	chainConfig *params.ChainConfig
-	iConfig     *light.IndexerConfig
-
-	client    bool   // The indicator whether the node is light client
-	maxPeers  int    // The maximum number peers allowed to connect.
-	networkId uint64 // The identity of network.
-
-	txpool       txPool
-	txrelay      *lesTxRelay
-	blockchain   BlockChain
-	chainDb      ethdb.Database
-	odr          *LesOdr
-	server       *LesServer
-	serverPool   *serverPool
-	lesTopic     discv5.Topic
-	reqDist      *requestDistributor
-	retriever    *retrieveManager
-	servingQueue *servingQueue
-	downloader   *downloader.Downloader
-	fetcher      *lightFetcher
-	ulc          *ulc
-	peers        *peerSet
-	checkpoint   *params.TrustedCheckpoint
-	reg          *checkpointOracle // If reg == nil, it means the checkpoint registrar is not activated
-
-	// channels for fetcher, syncer, txsyncLoop
-	newPeerCh   chan *peer
-	quitSync    chan struct{}
-	noMorePeers chan struct{}
-
-	wg       *sync.WaitGroup
-	eventMux *event.TypeMux
-
-	// Callbacks
-	synced func() bool
-
-	// Testing fields
-	addTxsSync bool
-}
-
-// NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable
-// with the ethereum network.
-func NewProtocolManager(chainConfig *params.ChainConfig, checkpoint *params.TrustedCheckpoint, indexerConfig *light.IndexerConfig, ulcServers []string, ulcFraction int, client bool, networkId uint64, mux *event.TypeMux, peers *peerSet, blockchain BlockChain, txpool txPool, chainDb ethdb.Database, odr *LesOdr, serverPool *serverPool, registrar *checkpointOracle, quitSync chan struct{}, wg *sync.WaitGroup, synced func() bool) (*ProtocolManager, error) {
-	// Create the protocol manager with the base fields
-	manager := &ProtocolManager{
-		client:      client,
-		eventMux:    mux,
-		blockchain:  blockchain,
-		chainConfig: chainConfig,
-		iConfig:     indexerConfig,
-		chainDb:     chainDb,
-		odr:         odr,
-		networkId:   networkId,
-		txpool:      txpool,
-		serverPool:  serverPool,
-		reg:         registrar,
-		peers:       peers,
-		newPeerCh:   make(chan *peer),
-		quitSync:    quitSync,
-		wg:          wg,
-		noMorePeers: make(chan struct{}),
-		checkpoint:  checkpoint,
-		synced:      synced,
-	}
-	if odr != nil {
-		manager.retriever = odr.retriever
-		manager.reqDist = odr.retriever.dist
-	}
-
-	if ulcServers != nil {
-		ulc, err := newULC(ulcServers, ulcFraction)
-		if err != nil {
-			log.Warn("Failed to initialize ultra light client", "err", err)
-		} else {
-			manager.ulc = ulc
-		}
-	}
-	removePeer := manager.removePeer
-	if disableClientRemovePeer {
-		removePeer = func(id string) {}
-	}
-	if client {
-		var checkpointNumber uint64
-		if checkpoint != nil {
-			checkpointNumber = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1
-		}
-		manager.downloader = downloader.New(checkpointNumber, chainDb, nil, manager.eventMux, nil, blockchain, removePeer)
-		manager.peers.notify((*downloaderPeerNotify)(manager))
-		manager.fetcher = newLightFetcher(manager)
-	}
-	return manager, nil
-}
-
-// removePeer initiates disconnection from a peer by removing it from the peer set
-func (pm *ProtocolManager) removePeer(id string) {
-	pm.peers.Unregister(id)
-}
-
-func (pm *ProtocolManager) Start(maxPeers int) {
-	pm.maxPeers = maxPeers
-	if pm.client {
-		go pm.syncer()
-	} else {
-		go func() {
-			for range pm.newPeerCh {
-			}
-		}()
-	}
-}
-
-func (pm *ProtocolManager) Stop() {
-	// Showing a log message. During download / process this could actually
-	// take between 5 to 10 seconds and therefor feedback is required.
-	log.Info("Stopping light Ethereum protocol")
-
-	// Quit the sync loop.
-	// After this send has completed, no new peers will be accepted.
-	pm.noMorePeers <- struct{}{}
-
-	close(pm.quitSync) // quits syncer, fetcher
-
-	if pm.servingQueue != nil {
-		pm.servingQueue.stop()
-	}
-
-	// Disconnect existing sessions.
-	// This also closes the gate for any new registrations on the peer set.
-	// sessions which are already established but not added to pm.peers yet
-	// will exit when they try to register.
-	pm.peers.Close()
-
-	// Wait for any process action
-	pm.wg.Wait()
-
-	log.Info("Light Ethereum protocol stopped")
-}
-
-// runPeer is the p2p protocol run function for the given version.
-func (pm *ProtocolManager) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
-	var entry *poolEntry
-	peer := pm.newPeer(int(version), pm.networkId, p, rw)
-	if pm.serverPool != nil {
-		entry = pm.serverPool.connect(peer, peer.Node())
-	}
-	peer.poolEntry = entry
-	select {
-	case pm.newPeerCh <- peer:
-		pm.wg.Add(1)
-		defer pm.wg.Done()
-		err := pm.handle(peer)
-		if entry != nil {
-			pm.serverPool.disconnect(entry)
-		}
-		return err
-	case <-pm.quitSync:
-		if entry != nil {
-			pm.serverPool.disconnect(entry)
-		}
-		return p2p.DiscQuitting
-	}
-}
-
-func (pm *ProtocolManager) newPeer(pv int, nv uint64, p *p2p.Peer, rw p2p.MsgReadWriter) *peer {
-	var trusted bool
-	if pm.ulc != nil {
-		trusted = pm.ulc.trusted(p.ID())
-	}
-	return newPeer(pv, nv, trusted, p, newMeteredMsgWriter(rw))
-}
-
-// handle is the callback invoked to manage the life cycle of a les peer. When
-// this function terminates, the peer is disconnected.
-func (pm *ProtocolManager) handle(p *peer) error {
-	// Ignore maxPeers if this is a trusted peer
-	// In server mode we try to check into the client pool after handshake
-	if pm.client && pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted {
-		clientRejectedMeter.Mark(1)
-		return p2p.DiscTooManyPeers
-	}
-	// Reject light clients if server is not synced.
-	if !pm.client && !pm.synced() {
-		clientRejectedMeter.Mark(1)
-		return p2p.DiscRequested
-	}
-	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
-
-	// Execute the LES handshake
-	var (
-		genesis = pm.blockchain.Genesis()
-		head    = pm.blockchain.CurrentHeader()
-		hash    = head.Hash()
-		number  = head.Number.Uint64()
-		td      = pm.blockchain.GetTd(hash, number)
-	)
-	if err := p.Handshake(td, hash, number, genesis.Hash(), pm.server); err != nil {
-		p.Log().Debug("Light Ethereum handshake failed", "err", err)
-		clientErrorMeter.Mark(1)
-		return err
-	}
-	if p.fcClient != nil {
-		defer p.fcClient.Disconnect()
-	}
-
-	if rw, ok := p.rw.(*meteredMsgReadWriter); ok {
-		rw.Init(p.version)
-	}
-
-	// Register the peer locally
-	if err := pm.peers.Register(p); err != nil {
-		clientErrorMeter.Mark(1)
-		p.Log().Error("Light Ethereum peer registration failed", "err", err)
-		return err
-	}
-	if !pm.client && p.balanceTracker == nil {
-		// add dummy balance tracker for tests
-		p.balanceTracker = &balanceTracker{}
-		p.balanceTracker.init(&mclock.System{}, 1)
-	}
-	connectedAt := time.Now()
-	defer func() {
-		p.balanceTracker = nil
-		pm.removePeer(p.id)
-		connectionTimer.UpdateSince(connectedAt)
-	}()
-
-	// Register the peer in the downloader. If the downloader considers it banned, we disconnect
-	if pm.client {
-		p.lock.Lock()
-		head := p.headInfo
-		p.lock.Unlock()
-		if pm.fetcher != nil {
-			pm.fetcher.announce(p, head)
-		}
-
-		if p.poolEntry != nil {
-			pm.serverPool.registered(p.poolEntry)
-		}
-	}
-	// main loop. handle incoming messages.
-	for {
-		if err := pm.handleMsg(p); err != nil {
-			p.Log().Debug("Light Ethereum message handling failed", "err", err)
-			if p.fcServer != nil {
-				p.fcServer.DumpLogs()
-			}
-			return err
-		}
-	}
-}
-
-// handleMsg is invoked whenever an inbound message is received from a remote
-// peer. The remote connection is torn down upon returning any error.
-func (pm *ProtocolManager) handleMsg(p *peer) error {
-	select {
-	case err := <-p.errCh:
-		return err
-	default:
-	}
-	// Read the next message from the remote peer, and ensure it's fully consumed
-	msg, err := p.rw.ReadMsg()
-	if err != nil {
-		return err
-	}
-	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
-
-	p.responseCount++
-	responseCount := p.responseCount
-	var (
-		maxCost uint64
-		task    *servingTask
-	)
-
-	accept := func(reqID, reqCnt, maxCnt uint64) bool {
-		inSizeCost := func() uint64 {
-			if pm.server.costTracker != nil {
-				return pm.server.costTracker.realCost(0, msg.Size, 0)
-			}
-			return 0
-		}
-		if p.isFrozen() || reqCnt == 0 || p.fcClient == nil || reqCnt > maxCnt {
-			p.fcClient.OneTimeCost(inSizeCost())
-			return false
-		}
-		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
-		gf := float64(1)
-		if pm.server.costTracker != nil {
-			gf = pm.server.costTracker.globalFactor()
-			if gf < 0.001 {
-				p.Log().Error("Invalid global cost factor", "globalFactor", gf)
-				gf = 1
-			}
-		}
-		maxTime := uint64(float64(maxCost) / gf)
-
-		if accepted, bufShort, servingPriority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost); !accepted {
-			p.freezeClient()
-			p.Log().Warn("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
-			p.fcClient.OneTimeCost(inSizeCost())
-			return false
-		} else {
-			task = pm.servingQueue.newTask(p, maxTime, servingPriority)
-		}
-		if task.start() {
-			return true
-		}
-		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost())
-		return false
-	}
-
-	if msg.Size > ProtocolMaxMsgSize {
-		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
-	}
-	defer msg.Discard()
-
-	var deliverMsg *Msg
-	balanceTracker := p.balanceTracker
-
-	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
-		p.responseLock.Lock()
-		defer p.responseLock.Unlock()
-
-		if p.isFrozen() {
-			amount = 0
-			reply = nil
-		}
-		var replySize uint32
-		if reply != nil {
-			replySize = reply.size()
-		}
-		var realCost uint64
-		if pm.server.costTracker != nil {
-			realCost = pm.server.costTracker.realCost(servingTime, msg.Size, replySize)
-			if amount != 0 {
-				pm.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
-				balanceTracker.requestCost(realCost)
-			}
-		} else {
-			realCost = maxCost
-		}
-		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
-		if reply != nil {
-			p.queueSend(func() {
-				if err := reply.send(bv); err != nil {
-					select {
-					case p.errCh <- err:
-					default:
-					}
-				}
-			})
-		}
-	}
-
-	// Handle the message depending on its contents
-	switch msg.Code {
-	case StatusMsg:
-		p.Log().Trace("Received status message")
-		// Status messages should never arrive after the handshake
-		return errResp(ErrExtraStatusMsg, "uncontrolled status message")
-
-	// Block header query, collect the requested headers and reply
-	case AnnounceMsg:
-		p.Log().Trace("Received announce message")
-		var req announceData
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "%v: %v", msg, err)
-		}
-		if err := req.sanityCheck(); err != nil {
-			return err
-		}
-		update, size := req.Update.decode()
-		if p.rejectUpdate(size) {
-			return errResp(ErrRequestRejected, "")
-		}
-		p.updateFlowControl(update)
-
-		if req.Hash != (common.Hash{}) {
-			if p.announceType == announceTypeNone {
-				return errResp(ErrUnexpectedResponse, "")
-			}
-			if p.announceType == announceTypeSigned {
-				if err := req.checkSignature(p.ID(), update); err != nil {
-					p.Log().Trace("Invalid announcement signature", "err", err)
-					return err
-				}
-				p.Log().Trace("Valid announcement signature")
-			}
-
-			p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth)
-			if pm.fetcher != nil {
-				pm.fetcher.announce(p, &req)
-			}
-		}
-
-	case GetBlockHeadersMsg:
-		p.Log().Trace("Received block header request")
-		// Decode the complex header query
-		var req struct {
-			ReqID uint64
-			Query getBlockHeadersData
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "%v: %v", msg, err)
-		}
-
-		query := req.Query
-		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
-			go func() {
-				hashMode := query.Origin.Hash != (common.Hash{})
-				first := true
-				maxNonCanonical := uint64(100)
-
-				// Gather headers until the fetch or network limits is reached
-				var (
-					bytes   common.StorageSize
-					headers []*types.Header
-					unknown bool
-				)
-				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
-					if !first && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					// Retrieve the next header satisfying the query
-					var origin *types.Header
-					if hashMode {
-						if first {
-							origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
-							if origin != nil {
-								query.Origin.Number = origin.Number.Uint64()
-							}
-						} else {
-							origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
-						}
-					} else {
-						origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
-					}
-					if origin == nil {
-						atomic.AddUint32(&p.invalidCount, 1)
-						break
-					}
-					headers = append(headers, origin)
-					bytes += estHeaderRlpSize
-
-					// Advance to the next header of the query
-					switch {
-					case hashMode && query.Reverse:
-						// Hash based traversal towards the genesis block
-						ancestor := query.Skip + 1
-						if ancestor == 0 {
-							unknown = true
-						} else {
-							query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
-							unknown = (query.Origin.Hash == common.Hash{})
-						}
-					case hashMode && !query.Reverse:
-						// Hash based traversal towards the leaf block
-						var (
-							current = origin.Number.Uint64()
-							next    = current + query.Skip + 1
-						)
-						if next <= current {
-							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
-							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
-							unknown = true
-						} else {
-							if header := pm.blockchain.GetHeaderByNumber(next); header != nil {
-								nextHash := header.Hash()
-								expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
-								if expOldHash == query.Origin.Hash {
-									query.Origin.Hash, query.Origin.Number = nextHash, next
-								} else {
-									unknown = true
-								}
-							} else {
-								unknown = true
-							}
-						}
-					case query.Reverse:
-						// Number based traversal towards the genesis block
-						if query.Origin.Number >= query.Skip+1 {
-							query.Origin.Number -= query.Skip + 1
-						} else {
-							unknown = true
-						}
-					case !query.Reverse:
-						// Number based traversal towards the leaf block
-						query.Origin.Number += query.Skip + 1
-					}
-					first = false
-				}
-				sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done())
-			}()
-		}
-
-	case BlockHeadersMsg:
-		if pm.downloader == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-
-		p.Log().Trace("Received block header response message")
-		// A batch of headers arrived to one of our previous requests
-		var resp struct {
-			ReqID, BV uint64
-			Headers   []*types.Header
-		}
-		if err := msg.Decode(&resp); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
-		if pm.fetcher != nil && pm.fetcher.requestedID(resp.ReqID) {
-			pm.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers)
-		} else {
-			err := pm.downloader.DeliverHeaders(p.id, resp.Headers)
-			if err != nil {
-				log.Debug(fmt.Sprint(err))
-			}
-		}
-
-	case GetBlockBodiesMsg:
-		p.Log().Trace("Received block bodies request")
-		// Decode the retrieval message
-		var req struct {
-			ReqID  uint64
-			Hashes []common.Hash
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		// Gather blocks until the fetch or network limits is reached
-		var (
-			bytes  int
-			bodies []rlp.RawValue
-		)
-		reqCnt := len(req.Hashes)
-		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
-			go func() {
-				for i, hash := range req.Hashes {
-					if i != 0 && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					// Retrieve the requested block body, stopping if enough was found
-					if bytes >= softResponseLimit {
-						break
-					}
-					number := rawdb.ReadHeaderNumber(pm.chainDb, hash)
-					if number == nil {
-						atomic.AddUint32(&p.invalidCount, 1)
-						continue
-					}
-					if data := rawdb.ReadBodyRLP(pm.chainDb, hash, *number); len(data) != 0 {
-						bodies = append(bodies, data)
-						bytes += len(data)
-					}
-				}
-				sendResponse(req.ReqID, uint64(reqCnt), p.ReplyBlockBodiesRLP(req.ReqID, bodies), task.done())
-			}()
-		}
-
-	case BlockBodiesMsg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-
-		p.Log().Trace("Received block bodies response")
-		// A batch of block bodies arrived to one of our previous requests
-		var resp struct {
-			ReqID, BV uint64
-			Data      []*types.Body
-		}
-		if err := msg.Decode(&resp); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
-		deliverMsg = &Msg{
-			MsgType: MsgBlockBodies,
-			ReqID:   resp.ReqID,
-			Obj:     resp.Data,
-		}
-
-	case GetCodeMsg:
-		p.Log().Trace("Received code request")
-		// Decode the retrieval message
-		var req struct {
-			ReqID uint64
-			Reqs  []CodeReq
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		// Gather state data until the fetch or network limits is reached
-		var (
-			bytes int
-			data  [][]byte
-		)
-		reqCnt := len(req.Reqs)
-		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
-			go func() {
-				for i, request := range req.Reqs {
-					if i != 0 && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					// Look up the root hash belonging to the request
-					number := rawdb.ReadHeaderNumber(pm.chainDb, request.BHash)
-					if number == nil {
-						p.Log().Warn("Failed to retrieve block num for code", "hash", request.BHash)
-						atomic.AddUint32(&p.invalidCount, 1)
-						continue
-					}
-					header := rawdb.ReadHeader(pm.chainDb, request.BHash, *number)
-					if header == nil {
-						p.Log().Warn("Failed to retrieve header for code", "block", *number, "hash", request.BHash)
-						continue
-					}
-					// Refuse to search stale state data in the database since looking for
-					// a non-exist key is kind of expensive.
-					local := pm.blockchain.CurrentHeader().Number.Uint64()
-					if !pm.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
-						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
-						atomic.AddUint32(&p.invalidCount, 1)
-						continue
-					}
-					triedb := pm.blockchain.StateCache().TrieDB()
-
-					account, err := pm.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
-					if err != nil {
-						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
-						atomic.AddUint32(&p.invalidCount, 1)
-						continue
-					}
-					code, err := triedb.Node(common.BytesToHash(account.CodeHash))
-					if err != nil {
-						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
-						continue
-					}
-					// Accumulate the code and abort if enough data was retrieved
-					data = append(data, code)
-					if bytes += len(code); bytes >= softResponseLimit {
-						break
-					}
-				}
-				sendResponse(req.ReqID, uint64(reqCnt), p.ReplyCode(req.ReqID, data), task.done())
-			}()
-		}
-
-	case CodeMsg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-
-		p.Log().Trace("Received code response")
-		// A batch of node state data arrived to one of our previous requests
-		var resp struct {
-			ReqID, BV uint64
-			Data      [][]byte
-		}
-		if err := msg.Decode(&resp); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
-		deliverMsg = &Msg{
-			MsgType: MsgCode,
-			ReqID:   resp.ReqID,
-			Obj:     resp.Data,
-		}
-
-	case GetReceiptsMsg:
-		p.Log().Trace("Received receipts request")
-		// Decode the retrieval message
-		var req struct {
-			ReqID  uint64
-			Hashes []common.Hash
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		// Gather state data until the fetch or network limits is reached
-		var (
-			bytes    int
-			receipts []rlp.RawValue
-		)
-		reqCnt := len(req.Hashes)
-		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
-			go func() {
-				for i, hash := range req.Hashes {
-					if i != 0 && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					if bytes >= softResponseLimit {
-						break
-					}
-					// Retrieve the requested block's receipts, skipping if unknown to us
-					var results types.Receipts
-					number := rawdb.ReadHeaderNumber(pm.chainDb, hash)
-					if number == nil {
-						atomic.AddUint32(&p.invalidCount, 1)
-						continue
-					}
-					results = rawdb.ReadRawReceipts(pm.chainDb, hash, *number)
-					if results == nil {
-						if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
-							continue
-						}
-					}
-					// If known, encode and queue for response packet
-					if encoded, err := rlp.EncodeToBytes(results); err != nil {
-						log.Error("Failed to encode receipt", "err", err)
-					} else {
-						receipts = append(receipts, encoded)
-						bytes += len(encoded)
-					}
-				}
-				sendResponse(req.ReqID, uint64(reqCnt), p.ReplyReceiptsRLP(req.ReqID, receipts), task.done())
-			}()
-		}
-
-	case ReceiptsMsg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-
-		p.Log().Trace("Received receipts response")
-		// A batch of receipts arrived to one of our previous requests
-		var resp struct {
-			ReqID, BV uint64
-			Receipts  []types.Receipts
-		}
-		if err := msg.Decode(&resp); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
-		deliverMsg = &Msg{
-			MsgType: MsgReceipts,
-			ReqID:   resp.ReqID,
-			Obj:     resp.Receipts,
-		}
-
-	case GetProofsV2Msg:
-		p.Log().Trace("Received les/2 proofs request")
-		// Decode the retrieval message
-		var req struct {
-			ReqID uint64
-			Reqs  []ProofReq
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		// Gather state data until the fetch or network limits is reached
-		var (
-			lastBHash common.Hash
-			root      common.Hash
-		)
-		reqCnt := len(req.Reqs)
-		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
-			go func() {
-				nodes := light.NewNodeSet()
-
-				for i, request := range req.Reqs {
-					if i != 0 && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					// Look up the root hash belonging to the request
-					var (
-						number *uint64
-						header *types.Header
-						trie   state.Trie
-					)
-					if request.BHash != lastBHash {
-						root, lastBHash = common.Hash{}, request.BHash
-
-						if number = rawdb.ReadHeaderNumber(pm.chainDb, request.BHash); number == nil {
-							p.Log().Warn("Failed to retrieve block num for proof", "hash", request.BHash)
-							atomic.AddUint32(&p.invalidCount, 1)
-							continue
-						}
-						if header = rawdb.ReadHeader(pm.chainDb, request.BHash, *number); header == nil {
-							p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", request.BHash)
-							continue
-						}
-						// Refuse to search stale state data in the database since looking for
-						// a non-exist key is kind of expensive.
-						local := pm.blockchain.CurrentHeader().Number.Uint64()
-						if !pm.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
-							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
-							atomic.AddUint32(&p.invalidCount, 1)
-							continue
-						}
-						root = header.Root
-					}
-					// If a header lookup failed (non existent), ignore subsequent requests for the same header
-					if root == (common.Hash{}) {
-						atomic.AddUint32(&p.invalidCount, 1)
-						continue
-					}
-					// Open the account or storage trie for the request
-					statedb := pm.blockchain.StateCache()
-
-					switch len(request.AccKey) {
-					case 0:
-						// No account key specified, open an account trie
-						trie, err = statedb.OpenTrie(root)
-						if trie == nil || err != nil {
-							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
-							continue
-						}
-					default:
-						// Account key specified, open a storage trie
-						account, err := pm.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
-						if err != nil {
-							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
-							atomic.AddUint32(&p.invalidCount, 1)
-							continue
-						}
-						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
-						if trie == nil || err != nil {
-							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
-							continue
-						}
-					}
-					// Prove the user's request from the account or stroage trie
-					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
-						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
-						continue
-					}
-					if nodes.DataSize() >= softResponseLimit {
-						break
-					}
-				}
-				sendResponse(req.ReqID, uint64(reqCnt), p.ReplyProofsV2(req.ReqID, nodes.NodeList()), task.done())
-			}()
-		}
-
-	case ProofsV2Msg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-
-		p.Log().Trace("Received les/2 proofs response")
-		// A batch of merkle proofs arrived to one of our previous requests
-		var resp struct {
-			ReqID, BV uint64
-			Data      light.NodeList
-		}
-		if err := msg.Decode(&resp); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
-		deliverMsg = &Msg{
-			MsgType: MsgProofsV2,
-			ReqID:   resp.ReqID,
-			Obj:     resp.Data,
-		}
-
-	case GetHelperTrieProofsMsg:
-		p.Log().Trace("Received helper trie proof request")
-		// Decode the retrieval message
-		var req struct {
-			ReqID uint64
-			Reqs  []HelperTrieReq
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		// Gather state data until the fetch or network limits is reached
-		var (
-			auxBytes int
-			auxData  [][]byte
-		)
-		reqCnt := len(req.Reqs)
-		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
-			go func() {
-
-				var (
-					lastIdx  uint64
-					lastType uint
-					root     common.Hash
-					auxTrie  *trie.Trie
-				)
-				nodes := light.NewNodeSet()
-				for i, request := range req.Reqs {
-					if i != 0 && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
-						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
-
-						var prefix string
-						if root, prefix = pm.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
-							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(pm.chainDb, prefix)))
-						}
-					}
-					if request.AuxReq == auxRoot {
-						var data []byte
-						if root != (common.Hash{}) {
-							data = root[:]
-						}
-						auxData = append(auxData, data)
-						auxBytes += len(data)
-					} else {
-						if auxTrie != nil {
-							auxTrie.Prove(request.Key, request.FromLevel, nodes)
-						}
-						if request.AuxReq != 0 {
-							data := pm.getHelperTrieAuxData(request)
-							auxData = append(auxData, data)
-							auxBytes += len(data)
-						}
-					}
-					if nodes.DataSize()+auxBytes >= softResponseLimit {
-						break
-					}
-				}
-				sendResponse(req.ReqID, uint64(reqCnt), p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}), task.done())
-			}()
-		}
-
-	case HelperTrieProofsMsg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-
-		p.Log().Trace("Received helper trie proof response")
-		var resp struct {
-			ReqID, BV uint64
-			Data      HelperTrieResps
-		}
-		if err := msg.Decode(&resp); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-
-		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
-		deliverMsg = &Msg{
-			MsgType: MsgHelperTrieProofs,
-			ReqID:   resp.ReqID,
-			Obj:     resp.Data,
-		}
-
-	case SendTxV2Msg:
-		if pm.txpool == nil {
-			return errResp(ErrRequestRejected, "")
-		}
-		// Transactions arrived, parse all of them and deliver to the pool
-		var req struct {
-			ReqID uint64
-			Txs   []*types.Transaction
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		reqCnt := len(req.Txs)
-		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
-			go func() {
-				stats := make([]light.TxStatus, len(req.Txs))
-				for i, tx := range req.Txs {
-					if i != 0 && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					hash := tx.Hash()
-					stats[i] = pm.txStatus(hash)
-					if stats[i].Status == core.TxStatusUnknown {
-						addFn := pm.txpool.AddRemotes
-						// Add txs synchronously for testing purpose
-						if pm.addTxsSync {
-							addFn = pm.txpool.AddRemotesSync
-						}
-						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
-							stats[i].Error = errs[0].Error()
-							continue
-						}
-						stats[i] = pm.txStatus(hash)
-					}
-				}
-				sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done())
-			}()
-		}
-
-	case GetTxStatusMsg:
-		if pm.txpool == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-		// Transactions arrived, parse all of them and deliver to the pool
-		var req struct {
-			ReqID  uint64
-			Hashes []common.Hash
-		}
-		if err := msg.Decode(&req); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		reqCnt := len(req.Hashes)
-		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
-			go func() {
-				stats := make([]light.TxStatus, len(req.Hashes))
-				for i, hash := range req.Hashes {
-					if i != 0 && !task.waitOrStop() {
-						sendResponse(req.ReqID, 0, nil, task.servingTime)
-						return
-					}
-					stats[i] = pm.txStatus(hash)
-				}
-				sendResponse(req.ReqID, uint64(reqCnt), p.ReplyTxStatus(req.ReqID, stats), task.done())
-			}()
-		}
-
-	case TxStatusMsg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-
-		p.Log().Trace("Received tx status response")
-		var resp struct {
-			ReqID, BV uint64
-			Status    []light.TxStatus
-		}
-		if err := msg.Decode(&resp); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-
-		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
-
-		p.Log().Trace("Received helper trie proof response")
-		deliverMsg = &Msg{
-			MsgType: MsgTxStatus,
-			ReqID:   resp.ReqID,
-			Obj:     resp.Status,
-		}
-
-	case StopMsg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-		p.freezeServer(true)
-		pm.retriever.frozen(p)
-		p.Log().Debug("Service stopped")
-
-	case ResumeMsg:
-		if pm.odr == nil {
-			return errResp(ErrUnexpectedResponse, "")
-		}
-		var bv uint64
-		if err := msg.Decode(&bv); err != nil {
-			return errResp(ErrDecode, "msg %v: %v", msg, err)
-		}
-		p.fcServer.ResumeFreeze(bv)
-		p.freezeServer(false)
-		p.Log().Debug("Service resumed")
-
-	default:
-		p.Log().Trace("Received unknown message", "code", msg.Code)
-		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
-	}
-
-	if deliverMsg != nil {
-		err := pm.retriever.deliver(p, deliverMsg)
-		if err != nil {
-			p.responseErrors++
-			if p.responseErrors > maxResponseErrors {
-				return err
-			}
-		}
-	}
-	// If the client has made too much invalid request(e.g. request a non-exist data),
-	// reject them to prevent SPAM attack.
-	if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors {
-		return errTooManyInvalidRequest
-	}
-	return nil
-}
-
-// getAccount retrieves an account from the state based at root.
-func (pm *ProtocolManager) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
-	trie, err := trie.New(root, triedb)
-	if err != nil {
-		return state.Account{}, err
-	}
-	blob, err := trie.TryGet(hash[:])
-	if err != nil {
-		return state.Account{}, err
-	}
-	var account state.Account
-	if err = rlp.DecodeBytes(blob, &account); err != nil {
-		return state.Account{}, err
-	}
-	return account, nil
-}
-
-// getHelperTrie returns the post-processed trie root for the given trie ID and section index
-func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) {
-	switch id {
-	case htCanonical:
-		sectionHead := rawdb.ReadCanonicalHash(pm.chainDb, (idx+1)*pm.iConfig.ChtSize-1)
-		return light.GetChtRoot(pm.chainDb, idx, sectionHead), light.ChtTablePrefix
-	case htBloomBits:
-		sectionHead := rawdb.ReadCanonicalHash(pm.chainDb, (idx+1)*pm.iConfig.BloomTrieSize-1)
-		return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix
-	}
-	return common.Hash{}, ""
-}
-
-// getHelperTrieAuxData returns requested auxiliary data for the given HelperTrie request
-func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte {
-	if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
-		blockNum := binary.BigEndian.Uint64(req.Key)
-		hash := rawdb.ReadCanonicalHash(pm.chainDb, blockNum)
-		return rawdb.ReadHeaderRLP(pm.chainDb, hash, blockNum)
-	}
-	return nil
-}
-
-func (pm *ProtocolManager) txStatus(hash common.Hash) light.TxStatus {
-	var stat light.TxStatus
-	stat.Status = pm.txpool.Status([]common.Hash{hash})[0]
-	// If the transaction is unknown to the pool, try looking it up locally
-	if stat.Status == core.TxStatusUnknown {
-		if tx, blockHash, blockNumber, txIndex := rawdb.ReadTransaction(pm.chainDb, hash); tx != nil {
-			stat.Status = core.TxStatusIncluded
-			stat.Lookup = &rawdb.LegacyTxLookupEntry{BlockHash: blockHash, BlockIndex: blockNumber, Index: txIndex}
-		}
-	}
-	return stat
-}
-
-// downloaderPeerNotify implements peerSetNotify
-type downloaderPeerNotify ProtocolManager
-
-type peerConnection struct {
-	manager *ProtocolManager
-	peer    *peer
-}
-
-func (pc *peerConnection) Head() (common.Hash, *big.Int) {
-	return pc.peer.HeadAndTd()
-}
-
-func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
-	reqID := genReqID()
-	rq := &distReq{
-		getCost: func(dp distPeer) uint64 {
-			peer := dp.(*peer)
-			return peer.GetRequestCost(GetBlockHeadersMsg, amount)
-		},
-		canSend: func(dp distPeer) bool {
-			return dp.(*peer) == pc.peer
-		},
-		request: func(dp distPeer) func() {
-			peer := dp.(*peer)
-			cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
-			peer.fcServer.QueuedRequest(reqID, cost)
-			return func() { peer.RequestHeadersByHash(reqID, cost, origin, amount, skip, reverse) }
-		},
-	}
-	_, ok := <-pc.manager.reqDist.queue(rq)
-	if !ok {
-		return light.ErrNoPeers
-	}
-	return nil
-}
-
-func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
-	reqID := genReqID()
-	rq := &distReq{
-		getCost: func(dp distPeer) uint64 {
-			peer := dp.(*peer)
-			return peer.GetRequestCost(GetBlockHeadersMsg, amount)
-		},
-		canSend: func(dp distPeer) bool {
-			return dp.(*peer) == pc.peer
-		},
-		request: func(dp distPeer) func() {
-			peer := dp.(*peer)
-			cost := peer.GetRequestCost(GetBlockHeadersMsg, amount)
-			peer.fcServer.QueuedRequest(reqID, cost)
-			return func() { peer.RequestHeadersByNumber(reqID, cost, origin, amount, skip, reverse) }
-		},
-	}
-	_, ok := <-pc.manager.reqDist.queue(rq)
-	if !ok {
-		return light.ErrNoPeers
-	}
-	return nil
-}
-
-func (d *downloaderPeerNotify) registerPeer(p *peer) {
-	pm := (*ProtocolManager)(d)
-	pc := &peerConnection{
-		manager: pm,
-		peer:    p,
-	}
-	pm.downloader.RegisterLightPeer(p.id, ethVersion, pc)
-}
-
-func (d *downloaderPeerNotify) unregisterPeer(p *peer) {
-	pm := (*ProtocolManager)(d)
-	pm.downloader.UnregisterPeer(p.id)
-}

+ 108 - 90
les/handler_test.go

@@ -48,11 +48,13 @@ func expectResponse(r p2p.MsgReader, msgcode, reqID, bv uint64, data interface{}
 
 
 // Tests that block headers can be retrieved from a remote chain based on user queries.
 // Tests that block headers can be retrieved from a remote chain based on user queries.
 func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) }
 func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) }
+func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) }
 
 
 func testGetBlockHeaders(t *testing.T, protocol int) {
 func testGetBlockHeaders(t *testing.T, protocol int) {
-	server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil)
+	server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+
+	bc := server.handler.blockchain
 
 
 	// Create a "random" unknown hash for testing
 	// Create a "random" unknown hash for testing
 	var unknown common.Hash
 	var unknown common.Hash
@@ -114,10 +116,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
 			[]common.Hash{bc.CurrentBlock().Hash()},
 			[]common.Hash{bc.CurrentBlock().Hash()},
 		},
 		},
 		// Ensure protocol limits are honored
 		// Ensure protocol limits are honored
-		/*{
-			&getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
-			bc.GetBlockHashesFromHash(bc.CurrentBlock().Hash(), limit),
-		},*/
+		//{
+		//	&getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true},
+		//	[]common.Hash{},
+		//},
 		// Check that requesting more than available is handled gracefully
 		// Check that requesting more than available is handled gracefully
 		{
 		{
 			&getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
 			&getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3},
@@ -165,9 +167,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
 		}
 		}
 		// Send the hash request and verify the response
 		// Send the hash request and verify the response
 		reqID++
 		reqID++
-		cost := server.tPeer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount))
-		sendRequest(server.tPeer.app, GetBlockHeadersMsg, reqID, cost, tt.query)
-		if err := expectResponse(server.tPeer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil {
+
+		cost := server.peer.peer.GetRequestCost(GetBlockHeadersMsg, int(tt.query.Amount))
+		sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, cost, tt.query)
+		if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, testBufLimit, headers); err != nil {
 			t.Errorf("test %d: headers mismatch: %v", i, err)
 			t.Errorf("test %d: headers mismatch: %v", i, err)
 		}
 		}
 	}
 	}
@@ -175,11 +178,13 @@ func testGetBlockHeaders(t *testing.T, protocol int) {
 
 
 // Tests that block contents can be retrieved from a remote chain based on their hashes.
 // Tests that block contents can be retrieved from a remote chain based on their hashes.
 func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) }
 func TestGetBlockBodiesLes2(t *testing.T) { testGetBlockBodies(t, 2) }
+func TestGetBlockBodiesLes3(t *testing.T) { testGetBlockBodies(t, 3) }
 
 
 func testGetBlockBodies(t *testing.T, protocol int) {
 func testGetBlockBodies(t *testing.T, protocol int) {
-	server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil)
+	server, tearDown := newServerEnv(t, downloader.MaxBlockFetch+15, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+
+	bc := server.handler.blockchain
 
 
 	// Create a batch of tests for various scenarios
 	// Create a batch of tests for various scenarios
 	limit := MaxBodyFetch
 	limit := MaxBodyFetch
@@ -239,10 +244,11 @@ func testGetBlockBodies(t *testing.T, protocol int) {
 			}
 			}
 		}
 		}
 		reqID++
 		reqID++
+
 		// Send the hash request and verify the response
 		// Send the hash request and verify the response
-		cost := server.tPeer.GetRequestCost(GetBlockBodiesMsg, len(hashes))
-		sendRequest(server.tPeer.app, GetBlockBodiesMsg, reqID, cost, hashes)
-		if err := expectResponse(server.tPeer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil {
+		cost := server.peer.peer.GetRequestCost(GetBlockBodiesMsg, len(hashes))
+		sendRequest(server.peer.app, GetBlockBodiesMsg, reqID, cost, hashes)
+		if err := expectResponse(server.peer.app, BlockBodiesMsg, reqID, testBufLimit, bodies); err != nil {
 			t.Errorf("test %d: bodies mismatch: %v", i, err)
 			t.Errorf("test %d: bodies mismatch: %v", i, err)
 		}
 		}
 	}
 	}
@@ -250,12 +256,13 @@ func testGetBlockBodies(t *testing.T, protocol int) {
 
 
 // Tests that the contract codes can be retrieved based on account addresses.
 // Tests that the contract codes can be retrieved based on account addresses.
 func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) }
 func TestGetCodeLes2(t *testing.T) { testGetCode(t, 2) }
+func TestGetCodeLes3(t *testing.T) { testGetCode(t, 3) }
 
 
 func testGetCode(t *testing.T, protocol int) {
 func testGetCode(t *testing.T, protocol int) {
 	// Assemble the test environment
 	// Assemble the test environment
-	server, tearDown := newServerEnv(t, 4, protocol, nil)
+	server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+	bc := server.handler.blockchain
 
 
 	var codereqs []*CodeReq
 	var codereqs []*CodeReq
 	var codes [][]byte
 	var codes [][]byte
@@ -271,9 +278,9 @@ func testGetCode(t *testing.T, protocol int) {
 		}
 		}
 	}
 	}
 
 
-	cost := server.tPeer.GetRequestCost(GetCodeMsg, len(codereqs))
-	sendRequest(server.tPeer.app, GetCodeMsg, 42, cost, codereqs)
-	if err := expectResponse(server.tPeer.app, CodeMsg, 42, testBufLimit, codes); err != nil {
+	cost := server.peer.peer.GetRequestCost(GetCodeMsg, len(codereqs))
+	sendRequest(server.peer.app, GetCodeMsg, 42, cost, codereqs)
+	if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, codes); err != nil {
 		t.Errorf("codes mismatch: %v", err)
 		t.Errorf("codes mismatch: %v", err)
 	}
 	}
 }
 }
@@ -283,18 +290,18 @@ func TestGetStaleCodeLes2(t *testing.T) { testGetStaleCode(t, 2) }
 func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) }
 func TestGetStaleCodeLes3(t *testing.T) { testGetStaleCode(t, 3) }
 
 
 func testGetStaleCode(t *testing.T, protocol int) {
 func testGetStaleCode(t *testing.T, protocol int) {
-	server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil)
+	server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+	bc := server.handler.blockchain
 
 
 	check := func(number uint64, expected [][]byte) {
 	check := func(number uint64, expected [][]byte) {
 		req := &CodeReq{
 		req := &CodeReq{
 			BHash:  bc.GetHeaderByNumber(number).Hash(),
 			BHash:  bc.GetHeaderByNumber(number).Hash(),
 			AccKey: crypto.Keccak256(testContractAddr[:]),
 			AccKey: crypto.Keccak256(testContractAddr[:]),
 		}
 		}
-		cost := server.tPeer.GetRequestCost(GetCodeMsg, 1)
-		sendRequest(server.tPeer.app, GetCodeMsg, 42, cost, []*CodeReq{req})
-		if err := expectResponse(server.tPeer.app, CodeMsg, 42, testBufLimit, expected); err != nil {
+		cost := server.peer.peer.GetRequestCost(GetCodeMsg, 1)
+		sendRequest(server.peer.app, GetCodeMsg, 42, cost, []*CodeReq{req})
+		if err := expectResponse(server.peer.app, CodeMsg, 42, testBufLimit, expected); err != nil {
 			t.Errorf("codes mismatch: %v", err)
 			t.Errorf("codes mismatch: %v", err)
 		}
 		}
 	}
 	}
@@ -305,12 +312,14 @@ func testGetStaleCode(t *testing.T, protocol int) {
 
 
 // Tests that the transaction receipts can be retrieved based on hashes.
 // Tests that the transaction receipts can be retrieved based on hashes.
 func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) }
 func TestGetReceiptLes2(t *testing.T) { testGetReceipt(t, 2) }
+func TestGetReceiptLes3(t *testing.T) { testGetReceipt(t, 3) }
 
 
 func testGetReceipt(t *testing.T, protocol int) {
 func testGetReceipt(t *testing.T, protocol int) {
 	// Assemble the test environment
 	// Assemble the test environment
-	server, tearDown := newServerEnv(t, 4, protocol, nil)
+	server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+
+	bc := server.handler.blockchain
 
 
 	// Collect the hashes to request, and the response to expect
 	// Collect the hashes to request, and the response to expect
 	var receipts []types.Receipts
 	var receipts []types.Receipts
@@ -322,26 +331,28 @@ func testGetReceipt(t *testing.T, protocol int) {
 		receipts = append(receipts, rawdb.ReadRawReceipts(server.db, block.Hash(), block.NumberU64()))
 		receipts = append(receipts, rawdb.ReadRawReceipts(server.db, block.Hash(), block.NumberU64()))
 	}
 	}
 	// Send the hash request and verify the response
 	// Send the hash request and verify the response
-	cost := server.tPeer.GetRequestCost(GetReceiptsMsg, len(hashes))
-	sendRequest(server.tPeer.app, GetReceiptsMsg, 42, cost, hashes)
-	if err := expectResponse(server.tPeer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil {
+	cost := server.peer.peer.GetRequestCost(GetReceiptsMsg, len(hashes))
+	sendRequest(server.peer.app, GetReceiptsMsg, 42, cost, hashes)
+	if err := expectResponse(server.peer.app, ReceiptsMsg, 42, testBufLimit, receipts); err != nil {
 		t.Errorf("receipts mismatch: %v", err)
 		t.Errorf("receipts mismatch: %v", err)
 	}
 	}
 }
 }
 
 
 // Tests that trie merkle proofs can be retrieved
 // Tests that trie merkle proofs can be retrieved
 func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) }
 func TestGetProofsLes2(t *testing.T) { testGetProofs(t, 2) }
+func TestGetProofsLes3(t *testing.T) { testGetProofs(t, 3) }
 
 
 func testGetProofs(t *testing.T, protocol int) {
 func testGetProofs(t *testing.T, protocol int) {
 	// Assemble the test environment
 	// Assemble the test environment
-	server, tearDown := newServerEnv(t, 4, protocol, nil)
+	server, tearDown := newServerEnv(t, 4, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+
+	bc := server.handler.blockchain
 
 
 	var proofreqs []ProofReq
 	var proofreqs []ProofReq
 	proofsV2 := light.NewNodeSet()
 	proofsV2 := light.NewNodeSet()
 
 
-	accounts := []common.Address{bankAddr, userAddr1, userAddr2, {}}
+	accounts := []common.Address{bankAddr, userAddr1, userAddr2, signerAddr, {}}
 	for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
 	for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
 		header := bc.GetHeaderByNumber(i)
 		header := bc.GetHeaderByNumber(i)
 		trie, _ := trie.New(header.Root, trie.NewDatabase(server.db))
 		trie, _ := trie.New(header.Root, trie.NewDatabase(server.db))
@@ -356,9 +367,9 @@ func testGetProofs(t *testing.T, protocol int) {
 		}
 		}
 	}
 	}
 	// Send the proof request and verify the response
 	// Send the proof request and verify the response
-	cost := server.tPeer.GetRequestCost(GetProofsV2Msg, len(proofreqs))
-	sendRequest(server.tPeer.app, GetProofsV2Msg, 42, cost, proofreqs)
-	if err := expectResponse(server.tPeer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil {
+	cost := server.peer.peer.GetRequestCost(GetProofsV2Msg, len(proofreqs))
+	sendRequest(server.peer.app, GetProofsV2Msg, 42, cost, proofreqs)
+	if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, proofsV2.NodeList()); err != nil {
 		t.Errorf("proofs mismatch: %v", err)
 		t.Errorf("proofs mismatch: %v", err)
 	}
 	}
 }
 }
@@ -368,9 +379,9 @@ func TestGetStaleProofLes2(t *testing.T) { testGetStaleProof(t, 2) }
 func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) }
 func TestGetStaleProofLes3(t *testing.T) { testGetStaleProof(t, 3) }
 
 
 func testGetStaleProof(t *testing.T, protocol int) {
 func testGetStaleProof(t *testing.T, protocol int) {
-	server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil)
+	server, tearDown := newServerEnv(t, core.TriesInMemory+4, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+	bc := server.handler.blockchain
 
 
 	check := func(number uint64, wantOK bool) {
 	check := func(number uint64, wantOK bool) {
 		var (
 		var (
@@ -381,8 +392,8 @@ func testGetStaleProof(t *testing.T, protocol int) {
 			BHash: header.Hash(),
 			BHash: header.Hash(),
 			Key:   account,
 			Key:   account,
 		}
 		}
-		cost := server.tPeer.GetRequestCost(GetProofsV2Msg, 1)
-		sendRequest(server.tPeer.app, GetProofsV2Msg, 42, cost, []*ProofReq{req})
+		cost := server.peer.peer.GetRequestCost(GetProofsV2Msg, 1)
+		sendRequest(server.peer.app, GetProofsV2Msg, 42, cost, []*ProofReq{req})
 
 
 		var expected []rlp.RawValue
 		var expected []rlp.RawValue
 		if wantOK {
 		if wantOK {
@@ -391,7 +402,7 @@ func testGetStaleProof(t *testing.T, protocol int) {
 			t.Prove(account, 0, proofsV2)
 			t.Prove(account, 0, proofsV2)
 			expected = proofsV2.NodeList()
 			expected = proofsV2.NodeList()
 		}
 		}
-		if err := expectResponse(server.tPeer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil {
+		if err := expectResponse(server.peer.app, ProofsV2Msg, 42, testBufLimit, expected); err != nil {
 			t.Errorf("codes mismatch: %v", err)
 			t.Errorf("codes mismatch: %v", err)
 		}
 		}
 	}
 	}
@@ -402,6 +413,7 @@ func testGetStaleProof(t *testing.T, protocol int) {
 
 
 // Tests that CHT proofs can be correctly retrieved.
 // Tests that CHT proofs can be correctly retrieved.
 func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) }
 func TestGetCHTProofsLes2(t *testing.T) { testGetCHTProofs(t, 2) }
+func TestGetCHTProofsLes3(t *testing.T) { testGetCHTProofs(t, 3) }
 
 
 func testGetCHTProofs(t *testing.T, protocol int) {
 func testGetCHTProofs(t *testing.T, protocol int) {
 	config := light.TestServerIndexerConfig
 	config := light.TestServerIndexerConfig
@@ -415,9 +427,10 @@ func testGetCHTProofs(t *testing.T, protocol int) {
 			time.Sleep(10 * time.Millisecond)
 			time.Sleep(10 * time.Millisecond)
 		}
 		}
 	}
 	}
-	server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers)
+	server, tearDown := newServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+
+	bc := server.handler.blockchain
 
 
 	// Assemble the proofs from the different protocols
 	// Assemble the proofs from the different protocols
 	header := bc.GetHeaderByNumber(config.ChtSize - 1)
 	header := bc.GetHeaderByNumber(config.ChtSize - 1)
@@ -440,15 +453,18 @@ func testGetCHTProofs(t *testing.T, protocol int) {
 		AuxReq:  auxHeader,
 		AuxReq:  auxHeader,
 	}}
 	}}
 	// Send the proof request and verify the response
 	// Send the proof request and verify the response
-	cost := server.tPeer.GetRequestCost(GetHelperTrieProofsMsg, len(requestsV2))
-	sendRequest(server.tPeer.app, GetHelperTrieProofsMsg, 42, cost, requestsV2)
-	if err := expectResponse(server.tPeer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil {
+	cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requestsV2))
+	sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requestsV2)
+	if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofsV2); err != nil {
 		t.Errorf("proofs mismatch: %v", err)
 		t.Errorf("proofs mismatch: %v", err)
 	}
 	}
 }
 }
 
 
+func TestGetBloombitsProofsLes2(t *testing.T) { testGetBloombitsProofs(t, 2) }
+func TestGetBloombitsProofsLes3(t *testing.T) { testGetBloombitsProofs(t, 3) }
+
 // Tests that bloombits proofs can be correctly retrieved.
 // Tests that bloombits proofs can be correctly retrieved.
-func TestGetBloombitsProofs(t *testing.T) {
+func testGetBloombitsProofs(t *testing.T, protocol int) {
 	config := light.TestServerIndexerConfig
 	config := light.TestServerIndexerConfig
 
 
 	waitIndexers := func(cIndexer, bIndexer, btIndexer *core.ChainIndexer) {
 	waitIndexers := func(cIndexer, bIndexer, btIndexer *core.ChainIndexer) {
@@ -460,9 +476,10 @@ func TestGetBloombitsProofs(t *testing.T) {
 			time.Sleep(10 * time.Millisecond)
 			time.Sleep(10 * time.Millisecond)
 		}
 		}
 	}
 	}
-	server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), 2, waitIndexers)
+	server, tearDown := newServerEnv(t, int(config.BloomTrieSize+config.BloomTrieConfirms), protocol, waitIndexers, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	bc := server.pm.blockchain.(*core.BlockChain)
+
+	bc := server.handler.blockchain
 
 
 	// Request and verify each bit of the bloom bits proofs
 	// Request and verify each bit of the bloom bits proofs
 	for bit := 0; bit < 2048; bit++ {
 	for bit := 0; bit < 2048; bit++ {
@@ -485,43 +502,39 @@ func TestGetBloombitsProofs(t *testing.T) {
 		trie.Prove(key, 0, &proofs.Proofs)
 		trie.Prove(key, 0, &proofs.Proofs)
 
 
 		// Send the proof request and verify the response
 		// Send the proof request and verify the response
-		cost := server.tPeer.GetRequestCost(GetHelperTrieProofsMsg, len(requests))
-		sendRequest(server.tPeer.app, GetHelperTrieProofsMsg, 42, cost, requests)
-		if err := expectResponse(server.tPeer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil {
+		cost := server.peer.peer.GetRequestCost(GetHelperTrieProofsMsg, len(requests))
+		sendRequest(server.peer.app, GetHelperTrieProofsMsg, 42, cost, requests)
+		if err := expectResponse(server.peer.app, HelperTrieProofsMsg, 42, testBufLimit, proofs); err != nil {
 			t.Errorf("bit %d: proofs mismatch: %v", bit, err)
 			t.Errorf("bit %d: proofs mismatch: %v", bit, err)
 		}
 		}
 	}
 	}
 }
 }
 
 
-func TestTransactionStatusLes2(t *testing.T) {
-	server, tearDown := newServerEnv(t, 0, 2, nil)
+func TestTransactionStatusLes2(t *testing.T) { testTransactionStatus(t, 2) }
+func TestTransactionStatusLes3(t *testing.T) { testTransactionStatus(t, 3) }
+
+func testTransactionStatus(t *testing.T, protocol int) {
+	server, tearDown := newServerEnv(t, 0, protocol, nil, false, true, 0)
 	defer tearDown()
 	defer tearDown()
-	server.pm.addTxsSync = true
+	server.handler.addTxsSync = true
 
 
-	chain := server.pm.blockchain.(*core.BlockChain)
-	config := core.DefaultTxPoolConfig
-	config.Journal = ""
-	txpool := core.NewTxPool(config, params.TestChainConfig, chain)
-	server.pm.txpool = txpool
-	peer, _ := newTestPeer(t, "peer", 2, server.pm, true, 0)
-	defer peer.close()
+	chain := server.handler.blockchain
 
 
 	var reqID uint64
 	var reqID uint64
 
 
 	test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) {
 	test := func(tx *types.Transaction, send bool, expStatus light.TxStatus) {
 		reqID++
 		reqID++
 		if send {
 		if send {
-			cost := server.tPeer.GetRequestCost(SendTxV2Msg, 1)
-			sendRequest(server.tPeer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx})
+			cost := server.peer.peer.GetRequestCost(SendTxV2Msg, 1)
+			sendRequest(server.peer.app, SendTxV2Msg, reqID, cost, types.Transactions{tx})
 		} else {
 		} else {
-			cost := server.tPeer.GetRequestCost(GetTxStatusMsg, 1)
-			sendRequest(server.tPeer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()})
+			cost := server.peer.peer.GetRequestCost(GetTxStatusMsg, 1)
+			sendRequest(server.peer.app, GetTxStatusMsg, reqID, cost, []common.Hash{tx.Hash()})
 		}
 		}
-		if err := expectResponse(server.tPeer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil {
+		if err := expectResponse(server.peer.app, TxStatusMsg, reqID, testBufLimit, []light.TxStatus{expStatus}); err != nil {
 			t.Errorf("transaction status mismatch")
 			t.Errorf("transaction status mismatch")
 		}
 		}
 	}
 	}
-
 	signer := types.HomesteadSigner{}
 	signer := types.HomesteadSigner{}
 
 
 	// test error status by sending an underpriced transaction
 	// test error status by sending an underpriced transaction
@@ -551,18 +564,22 @@ func TestTransactionStatusLes2(t *testing.T) {
 	}
 	}
 	// wait until TxPool processes the inserted block
 	// wait until TxPool processes the inserted block
 	for i := 0; i < 10; i++ {
 	for i := 0; i < 10; i++ {
-		if pending, _ := txpool.Stats(); pending == 1 {
+		if pending, _ := server.handler.txpool.Stats(); pending == 1 {
 			break
 			break
 		}
 		}
 		time.Sleep(100 * time.Millisecond)
 		time.Sleep(100 * time.Millisecond)
 	}
 	}
-	if pending, _ := txpool.Stats(); pending != 1 {
+	if pending, _ := server.handler.txpool.Stats(); pending != 1 {
 		t.Fatalf("pending count mismatch: have %d, want 1", pending)
 		t.Fatalf("pending count mismatch: have %d, want 1", pending)
 	}
 	}
+	// Discard new block announcement
+	msg, _ := server.peer.app.ReadMsg()
+	msg.Discard()
 
 
 	// check if their status is included now
 	// check if their status is included now
 	block1hash := rawdb.ReadCanonicalHash(server.db, 1)
 	block1hash := rawdb.ReadCanonicalHash(server.db, 1)
 	test(tx1, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}})
 	test(tx1, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}})
+
 	test(tx2, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}})
 	test(tx2, false, light.TxStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.LegacyTxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}})
 
 
 	// create a reorg that rolls them back
 	// create a reorg that rolls them back
@@ -572,46 +589,46 @@ func TestTransactionStatusLes2(t *testing.T) {
 	}
 	}
 	// wait until TxPool processes the reorg
 	// wait until TxPool processes the reorg
 	for i := 0; i < 10; i++ {
 	for i := 0; i < 10; i++ {
-		if pending, _ := txpool.Stats(); pending == 3 {
+		if pending, _ := server.handler.txpool.Stats(); pending == 3 {
 			break
 			break
 		}
 		}
 		time.Sleep(100 * time.Millisecond)
 		time.Sleep(100 * time.Millisecond)
 	}
 	}
-	if pending, _ := txpool.Stats(); pending != 3 {
+	if pending, _ := server.handler.txpool.Stats(); pending != 3 {
 		t.Fatalf("pending count mismatch: have %d, want 3", pending)
 		t.Fatalf("pending count mismatch: have %d, want 3", pending)
 	}
 	}
+	// Discard new block announcement
+	msg, _ = server.peer.app.ReadMsg()
+	msg.Discard()
+
 	// check if their status is pending again
 	// check if their status is pending again
 	test(tx1, false, light.TxStatus{Status: core.TxStatusPending})
 	test(tx1, false, light.TxStatus{Status: core.TxStatusPending})
 	test(tx2, false, light.TxStatus{Status: core.TxStatusPending})
 	test(tx2, false, light.TxStatus{Status: core.TxStatusPending})
 }
 }
 
 
 func TestStopResumeLes3(t *testing.T) {
 func TestStopResumeLes3(t *testing.T) {
-	db := rawdb.NewMemoryDatabase()
-	clock := &mclock.Simulated{}
-	testCost := testBufLimit / 10
-	pm, _, err := newTestProtocolManager(false, 0, nil, nil, nil, db, nil, 0, testCost, clock)
-	if err != nil {
-		t.Fatalf("Failed to create protocol manager: %v", err)
-	}
-	peer, _ := newTestPeer(t, "peer", 3, pm, true, testCost)
-	defer peer.close()
+	server, tearDown := newServerEnv(t, 0, 3, nil, true, true, testBufLimit/10)
+	defer tearDown()
 
 
-	expBuf := testBufLimit
-	var reqID uint64
+	server.handler.server.costTracker.testing = true
 
 
-	header := pm.blockchain.CurrentHeader()
+	var (
+		reqID    uint64
+		expBuf   = testBufLimit
+		testCost = testBufLimit / 10
+	)
+	header := server.handler.blockchain.CurrentHeader()
 	req := func() {
 	req := func() {
 		reqID++
 		reqID++
-		sendRequest(peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1})
+		sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, testCost, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1})
 	}
 	}
-
 	for i := 1; i <= 5; i++ {
 	for i := 1; i <= 5; i++ {
 		// send requests while we still have enough buffer and expect a response
 		// send requests while we still have enough buffer and expect a response
 		for expBuf >= testCost {
 		for expBuf >= testCost {
 			req()
 			req()
 			expBuf -= testCost
 			expBuf -= testCost
-			if err := expectResponse(peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil {
-				t.Fatalf("expected response and failed: %v", err)
+			if err := expectResponse(server.peer.app, BlockHeadersMsg, reqID, expBuf, []*types.Header{header}); err != nil {
+				t.Errorf("expected response and failed: %v", err)
 			}
 			}
 		}
 		}
 		// send some more requests in excess and expect a single StopMsg
 		// send some more requests in excess and expect a single StopMsg
@@ -620,15 +637,16 @@ func TestStopResumeLes3(t *testing.T) {
 			req()
 			req()
 			c--
 			c--
 		}
 		}
-		if err := p2p.ExpectMsg(peer.app, StopMsg, nil); err != nil {
+		if err := p2p.ExpectMsg(server.peer.app, StopMsg, nil); err != nil {
 			t.Errorf("expected StopMsg and failed: %v", err)
 			t.Errorf("expected StopMsg and failed: %v", err)
 		}
 		}
 		// wait until the buffer is recharged by half of the limit
 		// wait until the buffer is recharged by half of the limit
 		wait := testBufLimit / testBufRecharge / 2
 		wait := testBufLimit / testBufRecharge / 2
-		clock.Run(time.Millisecond * time.Duration(wait))
+		server.clock.(*mclock.Simulated).Run(time.Millisecond * time.Duration(wait))
+
 		// expect a ResumeMsg with the partially recharged buffer value
 		// expect a ResumeMsg with the partially recharged buffer value
 		expBuf += testBufRecharge * wait
 		expBuf += testBufRecharge * wait
-		if err := p2p.ExpectMsg(peer.app, ResumeMsg, expBuf); err != nil {
+		if err := p2p.ExpectMsg(server.peer.app, ResumeMsg, expBuf); err != nil {
 			t.Errorf("expected ResumeMsg and failed: %v", err)
 			t.Errorf("expected ResumeMsg and failed: %v", err)
 		}
 		}
 	}
 	}

+ 63 - 27
les/metrics.go

@@ -22,31 +22,73 @@ import (
 )
 )
 
 
 var (
 var (
-	miscInPacketsMeter  = metrics.NewRegisteredMeter("les/misc/in/packets", nil)
-	miscInTrafficMeter  = metrics.NewRegisteredMeter("les/misc/in/traffic", nil)
-	miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets", nil)
-	miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic", nil)
-
-	connectionTimer = metrics.NewRegisteredTimer("les/connectionTime", nil)
-
-	totalConnectedGauge     = metrics.NewRegisteredGauge("les/server/totalConnected", nil)
-	totalCapacityGauge      = metrics.NewRegisteredGauge("les/server/totalCapacity", nil)
-	totalRechargeGauge      = metrics.NewRegisteredGauge("les/server/totalRecharge", nil)
-	blockProcessingTimer    = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil)
-	requestServedTimer      = metrics.NewRegisteredTimer("les/server/requestServed", nil)
-	requestServedMeter      = metrics.NewRegisteredMeter("les/server/totalRequestServed", nil)
-	requestEstimatedMeter   = metrics.NewRegisteredMeter("les/server/totalRequestEstimated", nil)
-	relativeCostHistogram   = metrics.NewRegisteredHistogram("les/server/relativeCost", nil, metrics.NewExpDecaySample(1028, 0.015))
-	recentServedGauge       = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil)
-	recentEstimatedGauge    = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil)
-	sqServedGauge           = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil)
-	sqQueuedGauge           = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil)
+	miscInPacketsMeter           = metrics.NewRegisteredMeter("les/misc/in/packets/total", nil)
+	miscInTrafficMeter           = metrics.NewRegisteredMeter("les/misc/in/traffic/total", nil)
+	miscInHeaderPacketsMeter     = metrics.NewRegisteredMeter("les/misc/in/packets/header", nil)
+	miscInHeaderTrafficMeter     = metrics.NewRegisteredMeter("les/misc/in/traffic/header", nil)
+	miscInBodyPacketsMeter       = metrics.NewRegisteredMeter("les/misc/in/packets/body", nil)
+	miscInBodyTrafficMeter       = metrics.NewRegisteredMeter("les/misc/in/traffic/body", nil)
+	miscInCodePacketsMeter       = metrics.NewRegisteredMeter("les/misc/in/packets/code", nil)
+	miscInCodeTrafficMeter       = metrics.NewRegisteredMeter("les/misc/in/traffic/code", nil)
+	miscInReceiptPacketsMeter    = metrics.NewRegisteredMeter("les/misc/in/packets/receipt", nil)
+	miscInReceiptTrafficMeter    = metrics.NewRegisteredMeter("les/misc/in/traffic/receipt", nil)
+	miscInTrieProofPacketsMeter  = metrics.NewRegisteredMeter("les/misc/in/packets/proof", nil)
+	miscInTrieProofTrafficMeter  = metrics.NewRegisteredMeter("les/misc/in/traffic/proof", nil)
+	miscInHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/helperTrie", nil)
+	miscInHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/helperTrie", nil)
+	miscInTxsPacketsMeter        = metrics.NewRegisteredMeter("les/misc/in/packets/txs", nil)
+	miscInTxsTrafficMeter        = metrics.NewRegisteredMeter("les/misc/in/traffic/txs", nil)
+	miscInTxStatusPacketsMeter   = metrics.NewRegisteredMeter("les/misc/in/packets/txStatus", nil)
+	miscInTxStatusTrafficMeter   = metrics.NewRegisteredMeter("les/misc/in/traffic/txStatus", nil)
+
+	miscOutPacketsMeter           = metrics.NewRegisteredMeter("les/misc/out/packets/total", nil)
+	miscOutTrafficMeter           = metrics.NewRegisteredMeter("les/misc/out/traffic/total", nil)
+	miscOutHeaderPacketsMeter     = metrics.NewRegisteredMeter("les/misc/out/packets/header", nil)
+	miscOutHeaderTrafficMeter     = metrics.NewRegisteredMeter("les/misc/out/traffic/header", nil)
+	miscOutBodyPacketsMeter       = metrics.NewRegisteredMeter("les/misc/out/packets/body", nil)
+	miscOutBodyTrafficMeter       = metrics.NewRegisteredMeter("les/misc/out/traffic/body", nil)
+	miscOutCodePacketsMeter       = metrics.NewRegisteredMeter("les/misc/out/packets/code", nil)
+	miscOutCodeTrafficMeter       = metrics.NewRegisteredMeter("les/misc/out/traffic/code", nil)
+	miscOutReceiptPacketsMeter    = metrics.NewRegisteredMeter("les/misc/out/packets/receipt", nil)
+	miscOutReceiptTrafficMeter    = metrics.NewRegisteredMeter("les/misc/out/traffic/receipt", nil)
+	miscOutTrieProofPacketsMeter  = metrics.NewRegisteredMeter("les/misc/out/packets/proof", nil)
+	miscOutTrieProofTrafficMeter  = metrics.NewRegisteredMeter("les/misc/out/traffic/proof", nil)
+	miscOutHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/helperTrie", nil)
+	miscOutHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/helperTrie", nil)
+	miscOutTxsPacketsMeter        = metrics.NewRegisteredMeter("les/misc/out/packets/txs", nil)
+	miscOutTxsTrafficMeter        = metrics.NewRegisteredMeter("les/misc/out/traffic/txs", nil)
+	miscOutTxStatusPacketsMeter   = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil)
+	miscOutTxStatusTrafficMeter   = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil)
+
+	connectionTimer       = metrics.NewRegisteredTimer("les/connection/duration", nil)
+	serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil)
+	clientConnectionGauge = metrics.NewRegisteredGauge("les/connection/client", nil)
+
+	totalCapacityGauge   = metrics.NewRegisteredGauge("les/server/totalCapacity", nil)
+	totalRechargeGauge   = metrics.NewRegisteredGauge("les/server/totalRecharge", nil)
+	totalConnectedGauge  = metrics.NewRegisteredGauge("les/server/totalConnected", nil)
+	blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil)
+
+	requestServedMeter    = metrics.NewRegisteredMeter("les/server/req/avgServedTime", nil)
+	requestServedTimer    = metrics.NewRegisteredTimer("les/server/req/servedTime", nil)
+	requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/req/avgEstimatedTime", nil)
+	requestEstimatedTimer = metrics.NewRegisteredTimer("les/server/req/estimatedTime", nil)
+	relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/req/relative", nil, metrics.NewExpDecaySample(1028, 0.015))
+
+	recentServedGauge    = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil)
+	recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil)
+	sqServedGauge        = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil)
+	sqQueuedGauge        = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil)
+
 	clientConnectedMeter    = metrics.NewRegisteredMeter("les/server/clientEvent/connected", nil)
 	clientConnectedMeter    = metrics.NewRegisteredMeter("les/server/clientEvent/connected", nil)
 	clientRejectedMeter     = metrics.NewRegisteredMeter("les/server/clientEvent/rejected", nil)
 	clientRejectedMeter     = metrics.NewRegisteredMeter("les/server/clientEvent/rejected", nil)
 	clientKickedMeter       = metrics.NewRegisteredMeter("les/server/clientEvent/kicked", nil)
 	clientKickedMeter       = metrics.NewRegisteredMeter("les/server/clientEvent/kicked", nil)
 	clientDisconnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/disconnected", nil)
 	clientDisconnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/disconnected", nil)
 	clientFreezeMeter       = metrics.NewRegisteredMeter("les/server/clientEvent/freeze", nil)
 	clientFreezeMeter       = metrics.NewRegisteredMeter("les/server/clientEvent/freeze", nil)
 	clientErrorMeter        = metrics.NewRegisteredMeter("les/server/clientEvent/error", nil)
 	clientErrorMeter        = metrics.NewRegisteredMeter("les/server/clientEvent/error", nil)
+
+	requestRTT       = metrics.NewRegisteredTimer("les/client/req/rtt", nil)
+	requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil)
 )
 )
 
 
 // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of
 // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of
@@ -58,17 +100,11 @@ type meteredMsgReadWriter struct {
 
 
 // newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the
 // newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the
 // metrics system is disabled, this function returns the original object.
 // metrics system is disabled, this function returns the original object.
-func newMeteredMsgWriter(rw p2p.MsgReadWriter) p2p.MsgReadWriter {
+func newMeteredMsgWriter(rw p2p.MsgReadWriter, version int) p2p.MsgReadWriter {
 	if !metrics.Enabled {
 	if !metrics.Enabled {
 		return rw
 		return rw
 	}
 	}
-	return &meteredMsgReadWriter{MsgReadWriter: rw}
-}
-
-// Init sets the protocol version used by the stream to know which meters to
-// increment in case of overlapping message ids between protocol versions.
-func (rw *meteredMsgReadWriter) Init(version int) {
-	rw.version = version
+	return &meteredMsgReadWriter{MsgReadWriter: rw, version: version}
 }
 }
 
 
 func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) {
 func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) {

+ 4 - 1
les/odr.go

@@ -18,7 +18,9 @@ package les
 
 
 import (
 import (
 	"context"
 	"context"
+	"time"
 
 
+	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/light"
@@ -120,10 +122,11 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro
 			return func() { lreq.Request(reqID, p) }
 			return func() { lreq.Request(reqID, p) }
 		},
 		},
 	}
 	}
-
+	sent := mclock.Now()
 	if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil {
 	if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil {
 		// retrieved from network, store in db
 		// retrieved from network, store in db
 		req.StoreResult(odr.db)
 		req.StoreResult(odr.db)
+		requestRTT.Update(time.Duration(mclock.Now() - sent))
 	} else {
 	} else {
 		log.Debug("Failed to retrieve data from network", "err", err)
 		log.Debug("Failed to retrieve data from network", "err", err)
 	}
 	}

+ 22 - 16
les/odr_test.go

@@ -39,6 +39,7 @@ import (
 type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte
 type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte
 
 
 func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) }
 func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) }
+func TestOdrGetBlockLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetBlock) }
 
 
 func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 	var block *types.Block
 	var block *types.Block
@@ -55,6 +56,7 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon
 }
 }
 
 
 func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) }
 func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) }
+func TestOdrGetReceiptsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetReceipts) }
 
 
 func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 	var receipts types.Receipts
 	var receipts types.Receipts
@@ -75,6 +77,7 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain
 }
 }
 
 
 func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) }
 func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) }
+func TestOdrAccountsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrAccounts) }
 
 
 func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 	dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
 	dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
@@ -103,6 +106,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
 }
 }
 
 
 func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) }
 func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) }
+func TestOdrContractCallLes3(t *testing.T) { testOdr(t, 3, 2, true, odrContractCall) }
 
 
 type callmsg struct {
 type callmsg struct {
 	types.Message
 	types.Message
@@ -152,6 +156,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
 }
 }
 
 
 func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) }
 func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) }
+func TestOdrTxStatusLes3(t *testing.T) { testOdr(t, 3, 1, false, odrTxStatus) }
 
 
 func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
 	var txs types.Transactions
 	var txs types.Transactions
@@ -178,21 +183,22 @@ func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainCon
 // testOdr tests odr requests whose validation guaranteed by block headers.
 // testOdr tests odr requests whose validation guaranteed by block headers.
 func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn odrTestFn) {
 func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn odrTestFn) {
 	// Assemble the test environment
 	// Assemble the test environment
-	server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true)
+	server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true)
 	defer tearDown()
 	defer tearDown()
-	client.pm.synchronise(client.rPeer)
+
+	client.handler.synchronise(client.peer.peer)
 
 
 	test := func(expFail uint64) {
 	test := func(expFail uint64) {
 		// Mark this as a helper to put the failures at the correct lines
 		// Mark this as a helper to put the failures at the correct lines
 		t.Helper()
 		t.Helper()
 
 
-		for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ {
+		for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ {
 			bhash := rawdb.ReadCanonicalHash(server.db, i)
 			bhash := rawdb.ReadCanonicalHash(server.db, i)
-			b1 := fn(light.NoOdr, server.db, server.pm.chainConfig, server.pm.blockchain.(*core.BlockChain), nil, bhash)
+			b1 := fn(light.NoOdr, server.db, server.handler.server.chainConfig, server.handler.blockchain, nil, bhash)
 
 
 			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
 			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
-			defer cancel()
-			b2 := fn(ctx, client.db, client.pm.chainConfig, nil, client.pm.blockchain.(*light.LightChain), bhash)
+			b2 := fn(ctx, client.db, client.handler.backend.chainConfig, nil, client.handler.backend.blockchain, bhash)
+			cancel()
 
 
 			eq := bytes.Equal(b1, b2)
 			eq := bytes.Equal(b1, b2)
 			exp := i < expFail
 			exp := i < expFail
@@ -204,22 +210,22 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od
 			}
 			}
 		}
 		}
 	}
 	}
-	// temporarily remove peer to test odr fails
+
 	// expect retrievals to fail (except genesis block) without a les peer
 	// expect retrievals to fail (except genesis block) without a les peer
-	client.peers.Unregister(client.rPeer.id)
-	time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
+	client.handler.backend.peers.lock.Lock()
+	client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return false }
+	client.handler.backend.peers.lock.Unlock()
 	test(expFail)
 	test(expFail)
 
 
 	// expect all retrievals to pass
 	// expect all retrievals to pass
-	client.peers.Register(client.rPeer)
-	time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
-	client.peers.lock.Lock()
-	client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
-	client.peers.lock.Unlock()
+	client.handler.backend.peers.lock.Lock()
+	client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
+	client.handler.backend.peers.lock.Unlock()
 	test(5)
 	test(5)
+
+	// still expect all retrievals to pass, now data should be cached locally
 	if checkCached {
 	if checkCached {
-		// still expect all retrievals to pass, now data should be cached locally
-		client.peers.Unregister(client.rPeer.id)
+		client.handler.backend.peers.Unregister(client.peer.peer.id)
 		time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
 		time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
 		test(5)
 		test(5)
 	}
 	}

+ 25 - 23
les/peer.go

@@ -111,7 +111,7 @@ type peer struct {
 	fcServer       *flowcontrol.ServerNode // nil if the peer is client only
 	fcServer       *flowcontrol.ServerNode // nil if the peer is client only
 	fcParams       flowcontrol.ServerParams
 	fcParams       flowcontrol.ServerParams
 	fcCosts        requestCostTable
 	fcCosts        requestCostTable
-	balanceTracker *balanceTracker // set by clientPool.connect, used and removed by ProtocolManager.handle
+	balanceTracker *balanceTracker // set by clientPool.connect, used and removed by serverHandler.
 
 
 	trusted                 bool
 	trusted                 bool
 	onlyAnnounce            bool
 	onlyAnnounce            bool
@@ -291,6 +291,11 @@ func (p *peer) updateCapacity(cap uint64) {
 	p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) })
 	p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) })
 }
 }
 
 
+func (p *peer) responseID() uint64 {
+	p.responseCount += 1
+	return p.responseCount
+}
+
 func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error {
 func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error {
 	type req struct {
 	type req struct {
 		ReqID uint64
 		ReqID uint64
@@ -373,6 +378,7 @@ func (p *peer) HasBlock(hash common.Hash, number uint64, hasState bool) bool {
 	}
 	}
 	hasBlock := p.hasBlock
 	hasBlock := p.hasBlock
 	p.lock.RUnlock()
 	p.lock.RUnlock()
+
 	return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState)
 	return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState)
 }
 }
 
 
@@ -571,6 +577,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
 	defer p.lock.Unlock()
 	defer p.lock.Unlock()
 
 
 	var send keyValueList
 	var send keyValueList
+
+	// Add some basic handshake fields
 	send = send.add("protocolVersion", uint64(p.version))
 	send = send.add("protocolVersion", uint64(p.version))
 	send = send.add("networkId", p.network)
 	send = send.add("networkId", p.network)
 	send = send.add("headTd", td)
 	send = send.add("headTd", td)
@@ -578,7 +586,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
 	send = send.add("headNum", headNum)
 	send = send.add("headNum", headNum)
 	send = send.add("genesisHash", genesis)
 	send = send.add("genesisHash", genesis)
 	if server != nil {
 	if server != nil {
-		if !server.onlyAnnounce {
+		// Add some information which services server can offer.
+		if !server.config.UltraLightOnlyAnnounce {
 			send = send.add("serveHeaders", nil)
 			send = send.add("serveHeaders", nil)
 			send = send.add("serveChainSince", uint64(0))
 			send = send.add("serveChainSince", uint64(0))
 			send = send.add("serveStateSince", uint64(0))
 			send = send.add("serveStateSince", uint64(0))
@@ -594,25 +603,28 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
 		}
 		}
 		send = send.add("flowControl/BL", server.defParams.BufLimit)
 		send = send.add("flowControl/BL", server.defParams.BufLimit)
 		send = send.add("flowControl/MRR", server.defParams.MinRecharge)
 		send = send.add("flowControl/MRR", server.defParams.MinRecharge)
+
 		var costList RequestCostList
 		var costList RequestCostList
-		if server.costTracker != nil {
-			costList = server.costTracker.makeCostList(server.costTracker.globalFactor())
+		if server.costTracker.testCostList != nil {
+			costList = server.costTracker.testCostList
 		} else {
 		} else {
-			costList = testCostList(server.testCost)
+			costList = server.costTracker.makeCostList(server.costTracker.globalFactor())
 		}
 		}
 		send = send.add("flowControl/MRC", costList)
 		send = send.add("flowControl/MRC", costList)
 		p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)])
 		p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)])
 		p.fcParams = server.defParams
 		p.fcParams = server.defParams
 
 
-		if server.protocolManager != nil && server.protocolManager.reg != nil && server.protocolManager.reg.isRunning() {
-			cp, height := server.protocolManager.reg.stableCheckpoint()
+		// Add advertised checkpoint and register block height which
+		// client can verify the checkpoint validity.
+		if server.oracle != nil && server.oracle.isRunning() {
+			cp, height := server.oracle.stableCheckpoint()
 			if cp != nil {
 			if cp != nil {
 				send = send.add("checkpoint/value", cp)
 				send = send.add("checkpoint/value", cp)
 				send = send.add("checkpoint/registerHeight", height)
 				send = send.add("checkpoint/registerHeight", height)
 			}
 			}
 		}
 		}
 	} else {
 	} else {
-		//on client node
+		// Add some client-specific handshake fields
 		p.announceType = announceTypeSimple
 		p.announceType = announceTypeSimple
 		if p.trusted {
 		if p.trusted {
 			p.announceType = announceTypeSigned
 			p.announceType = announceTypeSigned
@@ -663,17 +675,12 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
 	}
 	}
 
 
 	if server != nil {
 	if server != nil {
-		// until we have a proper peer connectivity API, allow LES connection to other servers
-		/*if recv.get("serveStateSince", nil) == nil {
-			return errResp(ErrUselessPeer, "wanted client, got server")
-		}*/
 		if recv.get("announceType", &p.announceType) != nil {
 		if recv.get("announceType", &p.announceType) != nil {
-			//set default announceType on server side
+			// set default announceType on server side
 			p.announceType = announceTypeSimple
 			p.announceType = announceTypeSimple
 		}
 		}
 		p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams)
 		p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams)
 	} else {
 	} else {
-		//mark OnlyAnnounce server if "serveHeaders", "serveChainSince", "serveStateSince" or "txRelay" fields don't exist
 		if recv.get("serveChainSince", &p.chainSince) != nil {
 		if recv.get("serveChainSince", &p.chainSince) != nil {
 			p.onlyAnnounce = true
 			p.onlyAnnounce = true
 		}
 		}
@@ -730,15 +737,10 @@ func (p *peer) updateFlowControl(update keyValueMap) {
 	if p.fcServer == nil {
 	if p.fcServer == nil {
 		return
 		return
 	}
 	}
-	params := p.fcParams
-	updateParams := false
-	if update.get("flowControl/BL", &params.BufLimit) == nil {
-		updateParams = true
-	}
-	if update.get("flowControl/MRR", &params.MinRecharge) == nil {
-		updateParams = true
-	}
-	if updateParams {
+	// If any of the flow control params is nil, refuse to update.
+	var params flowcontrol.ServerParams
+	if update.get("flowControl/BL", &params.BufLimit) == nil && update.get("flowControl/MRR", &params.MinRecharge) == nil {
+		// todo can light client set a minimal acceptable flow control params?
 		p.fcParams = params
 		p.fcParams = params
 		p.fcServer.UpdateParams(params)
 		p.fcServer.UpdateParams(params)
 	}
 	}

+ 38 - 42
les/peer_test.go

@@ -18,47 +18,54 @@ package les
 
 
 import (
 import (
 	"math/big"
 	"math/big"
+	"net"
 	"testing"
 	"testing"
 
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/eth"
 	"github.com/ethereum/go-ethereum/les/flowcontrol"
 	"github.com/ethereum/go-ethereum/les/flowcontrol"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 )
 
 
-const (
-	test_networkid   = 10
-	protocol_version = lpv2
-)
+const protocolVersion = lpv2
 
 
 var (
 var (
-	hash    = common.HexToHash("some string")
-	genesis = common.HexToHash("genesis hash")
+	hash    = common.HexToHash("deadbeef")
+	genesis = common.HexToHash("cafebabe")
 	headNum = uint64(1234)
 	headNum = uint64(1234)
 	td      = big.NewInt(123)
 	td      = big.NewInt(123)
 )
 )
 
 
-//ulc connects to trusted peer and send announceType=announceTypeSigned
+func newNodeID(t *testing.T) *enode.Node {
+	key, err := crypto.GenerateKey()
+	if err != nil {
+		t.Fatal("generate key err:", err)
+	}
+	return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
+}
+
+// ulc connects to trusted peer and send announceType=announceTypeSigned
 func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) {
 func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) {
 	id := newNodeID(t).ID()
 	id := newNodeID(t).ID()
 
 
-	//peer to connect(on ulc side)
+	// peer to connect(on ulc side)
 	p := peer{
 	p := peer{
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
-		version: protocol_version,
+		version: protocolVersion,
 		trusted: true,
 		trusted: true,
 		rw: &rwStub{
 		rw: &rwStub{
 			WriteHook: func(recvList keyValueList) {
 			WriteHook: func(recvList keyValueList) {
-				//checking that ulc sends to peer allowedRequests=onlyAnnounceRequests and announceType = announceTypeSigned
 				recv, _ := recvList.decode()
 				recv, _ := recvList.decode()
 				var reqType uint64
 				var reqType uint64
-
 				err := recv.get("announceType", &reqType)
 				err := recv.get("announceType", &reqType)
 				if err != nil {
 				if err != nil {
 					t.Fatal(err)
 					t.Fatal(err)
 				}
 				}
-
 				if reqType != announceTypeSigned {
 				if reqType != announceTypeSigned {
 					t.Fatal("Expected announceTypeSigned")
 					t.Fatal("Expected announceTypeSigned")
 				}
 				}
@@ -71,18 +78,15 @@ func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testi
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
 				l = l.add("flowControl/MRC", testCostList(0))
 				l = l.add("flowControl/MRC", testCostList(0))
-
 				return l
 				return l
 			},
 			},
 		},
 		},
-		network: test_networkid,
+		network: NetworkId,
 	}
 	}
-
 	err := p.Handshake(td, hash, headNum, genesis, nil)
 	err := p.Handshake(td, hash, headNum, genesis, nil)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Handshake error: %s", err)
 		t.Fatalf("Handshake error: %s", err)
 	}
 	}
-
 	if p.announceType != announceTypeSigned {
 	if p.announceType != announceTypeSigned {
 		t.Fatal("Incorrect announceType")
 		t.Fatal("Incorrect announceType")
 	}
 	}
@@ -92,18 +96,16 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
 	id := newNodeID(t).ID()
 	id := newNodeID(t).ID()
 	p := peer{
 	p := peer{
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
-		version: protocol_version,
+		version: protocolVersion,
 		rw: &rwStub{
 		rw: &rwStub{
 			WriteHook: func(recvList keyValueList) {
 			WriteHook: func(recvList keyValueList) {
-				//checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned
+				// checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned
 				recv, _ := recvList.decode()
 				recv, _ := recvList.decode()
 				var reqType uint64
 				var reqType uint64
-
 				err := recv.get("announceType", &reqType)
 				err := recv.get("announceType", &reqType)
 				if err != nil {
 				if err != nil {
 					t.Fatal(err)
 					t.Fatal(err)
 				}
 				}
-
 				if reqType == announceTypeSigned {
 				if reqType == announceTypeSigned {
 					t.Fatal("Expected not announceTypeSigned")
 					t.Fatal("Expected not announceTypeSigned")
 				}
 				}
@@ -116,13 +118,11 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
 				l = l.add("flowControl/MRC", testCostList(0))
 				l = l.add("flowControl/MRC", testCostList(0))
-
 				return l
 				return l
 			},
 			},
 		},
 		},
-		network: test_networkid,
+		network: NetworkId,
 	}
 	}
-
 	err := p.Handshake(td, hash, headNum, genesis, nil)
 	err := p.Handshake(td, hash, headNum, genesis, nil)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
@@ -139,16 +139,15 @@ func TestPeerHandshakeDefaultAllRequests(t *testing.T) {
 
 
 	p := peer{
 	p := peer{
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
-		version: protocol_version,
+		version: protocolVersion,
 		rw: &rwStub{
 		rw: &rwStub{
 			ReadHook: func(l keyValueList) keyValueList {
 			ReadHook: func(l keyValueList) keyValueList {
 				l = l.add("announceType", uint64(announceTypeSigned))
 				l = l.add("announceType", uint64(announceTypeSigned))
 				l = l.add("allowedRequests", uint64(0))
 				l = l.add("allowedRequests", uint64(0))
-
 				return l
 				return l
 			},
 			},
 		},
 		},
-		network: test_networkid,
+		network: NetworkId,
 	}
 	}
 
 
 	err := p.Handshake(td, hash, headNum, genesis, s)
 	err := p.Handshake(td, hash, headNum, genesis, s)
@@ -165,15 +164,14 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) {
 	id := newNodeID(t).ID()
 	id := newNodeID(t).ID()
 
 
 	s := generateLesServer()
 	s := generateLesServer()
-	s.onlyAnnounce = true
+	s.config.UltraLightOnlyAnnounce = true
 
 
 	p := peer{
 	p := peer{
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
-		version: protocol_version,
+		version: protocolVersion,
 		rw: &rwStub{
 		rw: &rwStub{
 			ReadHook: func(l keyValueList) keyValueList {
 			ReadHook: func(l keyValueList) keyValueList {
 				l = l.add("announceType", uint64(announceTypeSigned))
 				l = l.add("announceType", uint64(announceTypeSigned))
-
 				return l
 				return l
 			},
 			},
 			WriteHook: func(l keyValueList) {
 			WriteHook: func(l keyValueList) {
@@ -187,7 +185,7 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) {
 				}
 				}
 			},
 			},
 		},
 		},
-		network: test_networkid,
+		network: NetworkId,
 	}
 	}
 
 
 	err := p.Handshake(td, hash, headNum, genesis, s)
 	err := p.Handshake(td, hash, headNum, genesis, s)
@@ -200,7 +198,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) {
 
 
 	p := peer{
 	p := peer{
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
-		version: protocol_version,
+		version: protocolVersion,
 		rw: &rwStub{
 		rw: &rwStub{
 			ReadHook: func(l keyValueList) keyValueList {
 			ReadHook: func(l keyValueList) keyValueList {
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/BL", uint64(0))
@@ -212,7 +210,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) {
 				return l
 				return l
 			},
 			},
 		},
 		},
-		network: test_networkid,
+		network: NetworkId,
 		trusted: true,
 		trusted: true,
 	}
 	}
 
 
@@ -231,19 +229,17 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) {
 
 
 	p := peer{
 	p := peer{
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
 		Peer:    p2p.NewPeer(id, "test peer", []p2p.Cap{}),
-		version: protocol_version,
+		version: protocolVersion,
 		rw: &rwStub{
 		rw: &rwStub{
 			ReadHook: func(l keyValueList) keyValueList {
 			ReadHook: func(l keyValueList) keyValueList {
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
 				l = l.add("flowControl/MRC", RequestCostList{})
 				l = l.add("flowControl/MRC", RequestCostList{})
-
 				l = l.add("announceType", uint64(announceTypeSigned))
 				l = l.add("announceType", uint64(announceTypeSigned))
-
 				return l
 				return l
 			},
 			},
 		},
 		},
-		network: test_networkid,
+		network: NetworkId,
 	}
 	}
 
 
 	err := p.Handshake(td, hash, headNum, genesis, nil)
 	err := p.Handshake(td, hash, headNum, genesis, nil)
@@ -254,12 +250,16 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) {
 
 
 func generateLesServer() *LesServer {
 func generateLesServer() *LesServer {
 	s := &LesServer{
 	s := &LesServer{
+		lesCommons: lesCommons{
+			config: &eth.Config{UltraLightOnlyAnnounce: true},
+		},
 		defParams: flowcontrol.ServerParams{
 		defParams: flowcontrol.ServerParams{
 			BufLimit:    uint64(300000000),
 			BufLimit:    uint64(300000000),
 			MinRecharge: uint64(50000),
 			MinRecharge: uint64(50000),
 		},
 		},
 		fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}),
 		fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}),
 	}
 	}
+	s.costTracker, _ = newCostTracker(rawdb.NewMemoryDatabase(), s.config)
 	return s
 	return s
 }
 }
 
 
@@ -270,8 +270,8 @@ type rwStub struct {
 
 
 func (s *rwStub) ReadMsg() (p2p.Msg, error) {
 func (s *rwStub) ReadMsg() (p2p.Msg, error) {
 	payload := keyValueList{}
 	payload := keyValueList{}
-	payload = payload.add("protocolVersion", uint64(protocol_version))
-	payload = payload.add("networkId", uint64(test_networkid))
+	payload = payload.add("protocolVersion", uint64(protocolVersion))
+	payload = payload.add("networkId", uint64(NetworkId))
 	payload = payload.add("headTd", td)
 	payload = payload.add("headTd", td)
 	payload = payload.add("headHash", hash)
 	payload = payload.add("headHash", hash)
 	payload = payload.add("headNum", headNum)
 	payload = payload.add("headNum", headNum)
@@ -280,12 +280,10 @@ func (s *rwStub) ReadMsg() (p2p.Msg, error) {
 	if s.ReadHook != nil {
 	if s.ReadHook != nil {
 		payload = s.ReadHook(payload)
 		payload = s.ReadHook(payload)
 	}
 	}
-
 	size, p, err := rlp.EncodeToReader(payload)
 	size, p, err := rlp.EncodeToReader(payload)
 	if err != nil {
 	if err != nil {
 		return p2p.Msg{}, err
 		return p2p.Msg{}, err
 	}
 	}
-
 	return p2p.Msg{
 	return p2p.Msg{
 		Size:    uint32(size),
 		Size:    uint32(size),
 		Payload: p,
 		Payload: p,
@@ -297,10 +295,8 @@ func (s *rwStub) WriteMsg(m p2p.Msg) error {
 	if err := m.Decode(&recvList); err != nil {
 	if err := m.Decode(&recvList); err != nil {
 		return err
 		return err
 	}
 	}
-
 	if s.WriteHook != nil {
 	if s.WriteHook != nil {
 		s.WriteHook(recvList)
 		s.WriteHook(recvList)
 	}
 	}
-
 	return nil
 	return nil
 }
 }

+ 10 - 18
les/request_test.go

@@ -37,18 +37,21 @@ func secAddr(addr common.Address) []byte {
 type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest
 type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest
 
 
 func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) }
 func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) }
+func TestBlockAccessLes3(t *testing.T) { testAccess(t, 3, tfBlockAccess) }
 
 
 func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
 func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
 	return &light.BlockRequest{Hash: bhash, Number: number}
 	return &light.BlockRequest{Hash: bhash, Number: number}
 }
 }
 
 
 func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) }
 func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) }
+func TestReceiptsAccessLes3(t *testing.T) { testAccess(t, 3, tfReceiptsAccess) }
 
 
 func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
 func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
 	return &light.ReceiptsRequest{Hash: bhash, Number: number}
 	return &light.ReceiptsRequest{Hash: bhash, Number: number}
 }
 }
 
 
 func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) }
 func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) }
+func TestTrieEntryAccessLes3(t *testing.T) { testAccess(t, 3, tfTrieEntryAccess) }
 
 
 func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
 func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
 	if number := rawdb.ReadHeaderNumber(db, bhash); number != nil {
 	if number := rawdb.ReadHeaderNumber(db, bhash); number != nil {
@@ -58,6 +61,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh
 }
 }
 
 
 func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) }
 func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) }
+func TestCodeAccessLes3(t *testing.T) { testAccess(t, 3, tfCodeAccess) }
 
 
 func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest {
 func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest {
 	number := rawdb.ReadHeaderNumber(db, bhash)
 	number := rawdb.ReadHeaderNumber(db, bhash)
@@ -75,17 +79,18 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq
 
 
 func testAccess(t *testing.T, protocol int, fn accessTestFn) {
 func testAccess(t *testing.T, protocol int, fn accessTestFn) {
 	// Assemble the test environment
 	// Assemble the test environment
-	server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true)
+	server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true)
 	defer tearDown()
 	defer tearDown()
-	client.pm.synchronise(client.rPeer)
+	client.handler.synchronise(client.peer.peer)
 
 
 	test := func(expFail uint64) {
 	test := func(expFail uint64) {
-		for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ {
+		for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ {
 			bhash := rawdb.ReadCanonicalHash(server.db, i)
 			bhash := rawdb.ReadCanonicalHash(server.db, i)
 			if req := fn(client.db, bhash, i); req != nil {
 			if req := fn(client.db, bhash, i); req != nil {
 				ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
 				ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
-				defer cancel()
-				err := client.pm.odr.Retrieve(ctx, req)
+				err := client.handler.backend.odr.Retrieve(ctx, req)
+				cancel()
+
 				got := err == nil
 				got := err == nil
 				exp := i < expFail
 				exp := i < expFail
 				if exp && !got {
 				if exp && !got {
@@ -97,18 +102,5 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) {
 			}
 			}
 		}
 		}
 	}
 	}
-
-	// temporarily remove peer to test odr fails
-	client.peers.Unregister(client.rPeer.id)
-	time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
-	// expect retrievals to fail (except genesis block) without a les peer
-	test(0)
-
-	client.peers.Register(client.rPeer)
-	time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
-	client.rPeer.lock.Lock()
-	client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
-	client.rPeer.lock.Unlock()
-	// expect all retrievals to pass
 	test(5)
 	test(5)
 }
 }

+ 143 - 197
les/server.go

@@ -18,15 +18,11 @@ package les
 
 
 import (
 import (
 	"crypto/ecdsa"
 	"crypto/ecdsa"
-	"sync"
 	"time"
 	"time"
 
 
 	"github.com/ethereum/go-ethereum/accounts/abi/bind"
 	"github.com/ethereum/go-ethereum/accounts/abi/bind"
-	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/core"
-	"github.com/ethereum/go-ethereum/core/rawdb"
-	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/eth"
 	"github.com/ethereum/go-ethereum/eth"
 	"github.com/ethereum/go-ethereum/les/flowcontrol"
 	"github.com/ethereum/go-ethereum/les/flowcontrol"
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/light"
@@ -38,80 +34,94 @@ import (
 	"github.com/ethereum/go-ethereum/rpc"
 	"github.com/ethereum/go-ethereum/rpc"
 )
 )
 
 
-const bufLimitRatio = 6000 // fixed bufLimit/MRR ratio
-
 type LesServer struct {
 type LesServer struct {
 	lesCommons
 	lesCommons
 
 
 	archiveMode bool // Flag whether the ethereum node runs in archive mode.
 	archiveMode bool // Flag whether the ethereum node runs in archive mode.
+	handler     *serverHandler
+	lesTopics   []discv5.Topic
+	privateKey  *ecdsa.PrivateKey
 
 
-	fcManager    *flowcontrol.ClientManager // nil if our node is client only
+	// Flow control and capacity management
+	fcManager    *flowcontrol.ClientManager
 	costTracker  *costTracker
 	costTracker  *costTracker
-	testCost     uint64
 	defParams    flowcontrol.ServerParams
 	defParams    flowcontrol.ServerParams
-	lesTopics    []discv5.Topic
-	privateKey   *ecdsa.PrivateKey
-	quitSync     chan struct{}
-	onlyAnnounce bool
-
-	thcNormal, thcBlockProcessing int // serving thread count for normal operation and block processing mode
+	servingQueue *servingQueue
+	clientPool   *clientPool
 
 
-	maxPeers                                int
-	minCapacity, maxCapacity, freeClientCap uint64
-	clientPool                              *clientPool
+	freeCapacity uint64 // The minimal client capacity used for free client.
+	threadsIdle  int    // Request serving threads count when system is idle.
+	threadsBusy  int    // Request serving threads count when system is busy(block insertion).
 }
 }
 
 
 func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
 func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
+	// Collect les protocol version information supported by local node.
 	lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions))
 	lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions))
 	for i, pv := range AdvertiseProtocolVersions {
 	for i, pv := range AdvertiseProtocolVersions {
 		lesTopics[i] = lesTopic(e.BlockChain().Genesis().Hash(), pv)
 		lesTopics[i] = lesTopic(e.BlockChain().Genesis().Hash(), pv)
 	}
 	}
-	quitSync := make(chan struct{})
+	// Calculate the number of threads used to service the light client
+	// requests based on the user-specified value.
+	threads := config.LightServ * 4 / 100
+	if threads < 4 {
+		threads = 4
+	}
 	srv := &LesServer{
 	srv := &LesServer{
 		lesCommons: lesCommons{
 		lesCommons: lesCommons{
+			genesis:          e.BlockChain().Genesis().Hash(),
 			config:           config,
 			config:           config,
+			chainConfig:      e.BlockChain().Config(),
 			iConfig:          light.DefaultServerIndexerConfig,
 			iConfig:          light.DefaultServerIndexerConfig,
 			chainDb:          e.ChainDb(),
 			chainDb:          e.ChainDb(),
+			peers:            newPeerSet(),
+			chainReader:      e.BlockChain(),
 			chtIndexer:       light.NewChtIndexer(e.ChainDb(), nil, params.CHTFrequency, params.HelperTrieProcessConfirmations),
 			chtIndexer:       light.NewChtIndexer(e.ChainDb(), nil, params.CHTFrequency, params.HelperTrieProcessConfirmations),
 			bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency),
 			bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency),
+			closeCh:          make(chan struct{}),
 		},
 		},
 		archiveMode:  e.ArchiveMode(),
 		archiveMode:  e.ArchiveMode(),
-		quitSync:     quitSync,
 		lesTopics:    lesTopics,
 		lesTopics:    lesTopics,
-		onlyAnnounce: config.UltraLightOnlyAnnounce,
+		fcManager:    flowcontrol.NewClientManager(nil, &mclock.System{}),
+		servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100),
+		threadsBusy:  config.LightServ/100 + 1,
+		threadsIdle:  threads,
 	}
 	}
-	srv.costTracker, srv.minCapacity = newCostTracker(e.ChainDb(), config)
+	srv.handler = newServerHandler(srv, e.BlockChain(), e.ChainDb(), e.TxPool(), e.Synced)
+	srv.costTracker, srv.freeCapacity = newCostTracker(e.ChainDb(), config)
 
 
-	logger := log.New()
-	srv.thcNormal = config.LightServ * 4 / 100
-	if srv.thcNormal < 4 {
-		srv.thcNormal = 4
+	// Set up checkpoint oracle.
+	oracle := config.CheckpointOracle
+	if oracle == nil {
+		oracle = params.CheckpointOracles[e.BlockChain().Genesis().Hash()]
 	}
 	}
-	srv.thcBlockProcessing = config.LightServ/100 + 1
-	srv.fcManager = flowcontrol.NewClientManager(nil, &mclock.System{})
+	srv.oracle = newCheckpointOracle(oracle, srv.localCheckpoint)
+
+	// Initialize server capacity management fields.
+	srv.defParams = flowcontrol.ServerParams{
+		BufLimit:    srv.freeCapacity * bufLimitRatio,
+		MinRecharge: srv.freeCapacity,
+	}
+	// LES flow control tries to more or less guarantee the possibility for the
+	// clients to send a certain amount of requests at any time and get a quick
+	// response. Most of the clients want this guarantee but don't actually need
+	// to send requests most of the time. Our goal is to serve as many clients as
+	// possible while the actually used server capacity does not exceed the limits
+	totalRecharge := srv.costTracker.totalRecharge()
+	maxCapacity := srv.freeCapacity * uint64(srv.config.LightPeers)
+	if totalRecharge > maxCapacity {
+		maxCapacity = totalRecharge
+	}
+	srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2)
+
+	srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) })
+	srv.peers.notify(srv.clientPool)
 
 
 	checkpoint := srv.latestLocalCheckpoint()
 	checkpoint := srv.latestLocalCheckpoint()
 	if !checkpoint.Empty() {
 	if !checkpoint.Empty() {
-		logger.Info("Loaded latest checkpoint", "section", checkpoint.SectionIndex, "head", checkpoint.SectionHead,
+		log.Info("Loaded latest checkpoint", "section", checkpoint.SectionIndex, "head", checkpoint.SectionHead,
 			"chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot)
 			"chtroot", checkpoint.CHTRoot, "bloomroot", checkpoint.BloomRoot)
 	}
 	}
-
 	srv.chtIndexer.Start(e.BlockChain())
 	srv.chtIndexer.Start(e.BlockChain())
-
-	oracle := config.CheckpointOracle
-	if oracle == nil {
-		oracle = params.CheckpointOracles[e.BlockChain().Genesis().Hash()]
-	}
-	registrar := newCheckpointOracle(oracle, srv.getLocalCheckpoint)
-	// TODO(rjl493456442) Checkpoint is useless for les server, separate handler for client and server.
-	pm, err := NewProtocolManager(e.BlockChain().Config(), nil, light.DefaultServerIndexerConfig, config.UltraLightServers, config.UltraLightFraction, false, config.NetworkId, e.EventMux(), newPeerSet(), e.BlockChain(), e.TxPool(), e.ChainDb(), nil, nil, registrar, quitSync, new(sync.WaitGroup), e.Synced)
-	if err != nil {
-		return nil, err
-	}
-	srv.protocolManager = pm
-	pm.servingQueue = newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100)
-	pm.server = srv
-
 	return srv, nil
 	return srv, nil
 }
 }
 
 
@@ -120,102 +130,29 @@ func (s *LesServer) APIs() []rpc.API {
 		{
 		{
 			Namespace: "les",
 			Namespace: "les",
 			Version:   "1.0",
 			Version:   "1.0",
-			Service:   NewPrivateLightAPI(&s.lesCommons, s.protocolManager.reg),
+			Service:   NewPrivateLightAPI(&s.lesCommons),
 			Public:    false,
 			Public:    false,
 		},
 		},
 	}
 	}
 }
 }
 
 
-// startEventLoop starts an event handler loop that updates the recharge curve of
-// the client manager and adjusts the client pool's size according to the total
-// capacity updates coming from the client manager
-func (s *LesServer) startEventLoop() {
-	s.protocolManager.wg.Add(1)
-
-	var (
-		processing, procLast bool
-		procStarted          time.Time
-	)
-	blockProcFeed := make(chan bool, 100)
-	s.protocolManager.blockchain.(*core.BlockChain).SubscribeBlockProcessingEvent(blockProcFeed)
-	totalRechargeCh := make(chan uint64, 100)
-	totalRecharge := s.costTracker.subscribeTotalRecharge(totalRechargeCh)
-	totalCapacityCh := make(chan uint64, 100)
-	updateRecharge := func() {
-		if processing {
-			if !procLast {
-				procStarted = time.Now()
-			}
-			s.protocolManager.servingQueue.setThreads(s.thcBlockProcessing)
-			s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}})
-		} else {
-			if procLast {
-				blockProcessingTimer.UpdateSince(procStarted)
-			}
-			s.protocolManager.servingQueue.setThreads(s.thcNormal)
-			s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 16, totalRecharge / 2}, {totalRecharge / 2, totalRecharge / 2}, {totalRecharge, totalRecharge}})
-		}
-		procLast = processing
-	}
-	updateRecharge()
-	totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh)
-	s.clientPool.setLimits(s.maxPeers, totalCapacity)
-
-	var maxFreePeers uint64
-	go func() {
-		for {
-			select {
-			case processing = <-blockProcFeed:
-				updateRecharge()
-			case totalRecharge = <-totalRechargeCh:
-				updateRecharge()
-			case totalCapacity = <-totalCapacityCh:
-				totalCapacityGauge.Update(int64(totalCapacity))
-				newFreePeers := totalCapacity / s.freeClientCap
-				if newFreePeers < maxFreePeers && newFreePeers < uint64(s.maxPeers) {
-					log.Warn("Reduced total capacity", "maxFreePeers", newFreePeers)
-				}
-				maxFreePeers = newFreePeers
-				s.clientPool.setLimits(s.maxPeers, totalCapacity)
-			case <-s.protocolManager.quitSync:
-				s.protocolManager.wg.Done()
-				return
-			}
-		}
-	}()
-}
-
 func (s *LesServer) Protocols() []p2p.Protocol {
 func (s *LesServer) Protocols() []p2p.Protocol {
-	return s.makeProtocols(ServerProtocolVersions)
+	return s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
+		if p := s.peers.Peer(peerIdToString(id)); p != nil {
+			return p.Info()
+		}
+		return nil
+	})
 }
 }
 
 
 // Start starts the LES server
 // Start starts the LES server
 func (s *LesServer) Start(srvr *p2p.Server) {
 func (s *LesServer) Start(srvr *p2p.Server) {
-	s.maxPeers = s.config.LightPeers
-	totalRecharge := s.costTracker.totalRecharge()
-	if s.maxPeers > 0 {
-		s.freeClientCap = s.minCapacity //totalRecharge / uint64(s.maxPeers)
-		if s.freeClientCap < s.minCapacity {
-			s.freeClientCap = s.minCapacity
-		}
-		if s.freeClientCap > 0 {
-			s.defParams = flowcontrol.ServerParams{
-				BufLimit:    s.freeClientCap * bufLimitRatio,
-				MinRecharge: s.freeClientCap,
-			}
-		}
-	}
+	s.privateKey = srvr.PrivateKey
+	s.handler.start()
+
+	s.wg.Add(1)
+	go s.capacityManagement()
 
 
-	s.maxCapacity = s.freeClientCap * uint64(s.maxPeers)
-	if totalRecharge > s.maxCapacity {
-		s.maxCapacity = totalRecharge
-	}
-	s.fcManager.SetCapacityLimits(s.freeClientCap, s.maxCapacity, s.freeClientCap*2)
-	s.clientPool = newClientPool(s.chainDb, s.freeClientCap, 10000, mclock.System{}, func(id enode.ID) { go s.protocolManager.removePeer(peerIdToString(id)) })
-	s.clientPool.setPriceFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1})
-	s.protocolManager.peers.notify(s.clientPool)
-	s.startEventLoop()
-	s.protocolManager.Start(s.config.LightPeers)
 	if srvr.DiscV5 != nil {
 	if srvr.DiscV5 != nil {
 		for _, topic := range s.lesTopics {
 		for _, topic := range s.lesTopics {
 			topic := topic
 			topic := topic
@@ -224,12 +161,32 @@ func (s *LesServer) Start(srvr *p2p.Server) {
 				logger.Info("Starting topic registration")
 				logger.Info("Starting topic registration")
 				defer logger.Info("Terminated topic registration")
 				defer logger.Info("Terminated topic registration")
 
 
-				srvr.DiscV5.RegisterTopic(topic, s.quitSync)
+				srvr.DiscV5.RegisterTopic(topic, s.closeCh)
 			}()
 			}()
 		}
 		}
 	}
 	}
-	s.privateKey = srvr.PrivateKey
-	s.protocolManager.blockLoop()
+}
+
+// Stop stops the LES service
+func (s *LesServer) Stop() {
+	close(s.closeCh)
+
+	// Disconnect existing sessions.
+	// This also closes the gate for any new registrations on the peer set.
+	// sessions which are already established but not added to pm.peers yet
+	// will exit when they try to register.
+	s.peers.Close()
+
+	s.fcManager.Stop()
+	s.clientPool.stop()
+	s.costTracker.stop()
+	s.handler.stop()
+	s.servingQueue.stop()
+
+	// Note, bloom trie indexer is closed by parent bloombits indexer.
+	s.chtIndexer.Close()
+	s.wg.Wait()
+	log.Info("Les server stopped")
 }
 }
 
 
 func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) {
 func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) {
@@ -238,78 +195,67 @@ func (s *LesServer) SetBloomBitsIndexer(bloomIndexer *core.ChainIndexer) {
 
 
 // SetClient sets the rpc client and starts running checkpoint contract if it is not yet watched.
 // SetClient sets the rpc client and starts running checkpoint contract if it is not yet watched.
 func (s *LesServer) SetContractBackend(backend bind.ContractBackend) {
 func (s *LesServer) SetContractBackend(backend bind.ContractBackend) {
-	if s.protocolManager.reg != nil {
-		s.protocolManager.reg.start(backend)
+	if s.oracle == nil {
+		return
 	}
 	}
+	s.oracle.start(backend)
 }
 }
 
 
-// Stop stops the LES service
-func (s *LesServer) Stop() {
-	s.fcManager.Stop()
-	s.chtIndexer.Close()
-	// bloom trie indexer is closed by parent bloombits indexer
-	go func() {
-		<-s.protocolManager.noMorePeers
-	}()
-	s.clientPool.stop()
-	s.costTracker.stop()
-	s.protocolManager.Stop()
-}
+// capacityManagement starts an event handler loop that updates the recharge curve of
+// the client manager and adjusts the client pool's size according to the total
+// capacity updates coming from the client manager
+func (s *LesServer) capacityManagement() {
+	defer s.wg.Done()
 
 
-// todo(rjl493456442) separate client and server implementation.
-func (pm *ProtocolManager) blockLoop() {
-	pm.wg.Add(1)
-	headCh := make(chan core.ChainHeadEvent, 10)
-	headSub := pm.blockchain.SubscribeChainHeadEvent(headCh)
-	go func() {
-		var lastHead *types.Header
-		lastBroadcastTd := common.Big0
-		for {
-			select {
-			case ev := <-headCh:
-				peers := pm.peers.AllPeers()
-				if len(peers) > 0 {
-					header := ev.Block.Header()
-					hash := header.Hash()
-					number := header.Number.Uint64()
-					td := rawdb.ReadTd(pm.chainDb, hash, number)
-					if td != nil && td.Cmp(lastBroadcastTd) > 0 {
-						var reorg uint64
-						if lastHead != nil {
-							reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64()
-						}
-						lastHead = header
-						lastBroadcastTd = td
+	processCh := make(chan bool, 100)
+	sub := s.handler.blockchain.SubscribeBlockProcessingEvent(processCh)
+	defer sub.Unsubscribe()
 
 
-						log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
+	totalRechargeCh := make(chan uint64, 100)
+	totalRecharge := s.costTracker.subscribeTotalRecharge(totalRechargeCh)
 
 
-						announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
-						var (
-							signed         bool
-							signedAnnounce announceData
-						)
+	totalCapacityCh := make(chan uint64, 100)
+	totalCapacity := s.fcManager.SubscribeTotalCapacity(totalCapacityCh)
+	s.clientPool.setLimits(s.config.LightPeers, totalCapacity)
 
 
-						for _, p := range peers {
-							p := p
-							switch p.announceType {
-							case announceTypeSimple:
-								p.queueSend(func() { p.SendAnnounce(announce) })
-							case announceTypeSigned:
-								if !signed {
-									signedAnnounce = announce
-									signedAnnounce.sign(pm.server.privateKey)
-									signed = true
-								}
-								p.queueSend(func() { p.SendAnnounce(signedAnnounce) })
-							}
-						}
-					}
-				}
-			case <-pm.quitSync:
-				headSub.Unsubscribe()
-				pm.wg.Done()
-				return
+	var (
+		busy         bool
+		freePeers    uint64
+		blockProcess mclock.AbsTime
+	)
+	updateRecharge := func() {
+		if busy {
+			s.servingQueue.setThreads(s.threadsBusy)
+			s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge, totalRecharge}})
+		} else {
+			s.servingQueue.setThreads(s.threadsIdle)
+			s.fcManager.SetRechargeCurve(flowcontrol.PieceWiseLinear{{0, 0}, {totalRecharge / 10, totalRecharge}, {totalRecharge, totalRecharge}})
+		}
+	}
+	updateRecharge()
+
+	for {
+		select {
+		case busy = <-processCh:
+			if busy {
+				blockProcess = mclock.Now()
+			} else {
+				blockProcessingTimer.Update(time.Duration(mclock.Now() - blockProcess))
 			}
 			}
+			updateRecharge()
+		case totalRecharge = <-totalRechargeCh:
+			totalRechargeGauge.Update(int64(totalRecharge))
+			updateRecharge()
+		case totalCapacity = <-totalCapacityCh:
+			totalCapacityGauge.Update(int64(totalCapacity))
+			newFreePeers := totalCapacity / s.freeCapacity
+			if newFreePeers < freePeers && newFreePeers < uint64(s.config.LightPeers) {
+				log.Warn("Reduced free peer connections", "from", freePeers, "to", newFreePeers)
+			}
+			freePeers = newFreePeers
+			s.clientPool.setLimits(s.config.LightPeers, totalCapacity)
+		case <-s.closeCh:
+			return
 		}
 		}
-	}()
+	}
 }
 }

+ 921 - 0
les/server_handler.go

@@ -0,0 +1,921 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package les
+
+import (
+	"encoding/binary"
+	"encoding/json"
+	"errors"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/state"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/light"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/rlp"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
+const (
+	softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data.
+	estHeaderRlpSize  = 500             // Approximate size of an RLP encoded block header
+	ethVersion        = 63              // equivalent eth version for the downloader
+
+	MaxHeaderFetch           = 192 // Amount of block headers to be fetched per retrieval request
+	MaxBodyFetch             = 32  // Amount of block bodies to be fetched per retrieval request
+	MaxReceiptFetch          = 128 // Amount of transaction receipts to allow fetching per request
+	MaxCodeFetch             = 64  // Amount of contract codes to allow fetching per request
+	MaxProofsFetch           = 64  // Amount of merkle proofs to be fetched per retrieval request
+	MaxHelperTrieProofsFetch = 64  // Amount of helper tries to be fetched per retrieval request
+	MaxTxSend                = 64  // Amount of transactions to be send per request
+	MaxTxStatus              = 256 // Amount of transactions to queried per request
+)
+
+var errTooManyInvalidRequest = errors.New("too many invalid requests made")
+
+// serverHandler is responsible for serving light client and process
+// all incoming light requests.
+type serverHandler struct {
+	blockchain *core.BlockChain
+	chainDb    ethdb.Database
+	txpool     *core.TxPool
+	server     *LesServer
+
+	closeCh chan struct{}  // Channel used to exit all background routines of handler.
+	wg      sync.WaitGroup // WaitGroup used to track all background routines of handler.
+	synced  func() bool    // Callback function used to determine whether local node is synced.
+
+	// Testing fields
+	addTxsSync bool
+}
+
+func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler {
+	handler := &serverHandler{
+		server:     server,
+		blockchain: blockchain,
+		chainDb:    chainDb,
+		txpool:     txpool,
+		closeCh:    make(chan struct{}),
+		synced:     synced,
+	}
+	return handler
+}
+
+// start starts the server handler.
+func (h *serverHandler) start() {
+	h.wg.Add(1)
+	go h.broadcastHeaders()
+}
+
+// stop stops the server handler.
+func (h *serverHandler) stop() {
+	close(h.closeCh)
+	h.wg.Wait()
+}
+
+// runPeer is the p2p protocol run function for the given version.
+func (h *serverHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error {
+	peer := newPeer(int(version), h.server.config.NetworkId, false, p, newMeteredMsgWriter(rw, int(version)))
+	h.wg.Add(1)
+	defer h.wg.Done()
+	return h.handle(peer)
+}
+
+func (h *serverHandler) handle(p *peer) error {
+	// Reject light clients if server is not synced.
+	if !h.synced() {
+		return p2p.DiscRequested
+	}
+	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
+
+	// Execute the LES handshake
+	var (
+		head   = h.blockchain.CurrentHeader()
+		hash   = head.Hash()
+		number = head.Number.Uint64()
+		td     = h.blockchain.GetTd(hash, number)
+	)
+	if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil {
+		p.Log().Debug("Light Ethereum handshake failed", "err", err)
+		return err
+	}
+	defer p.fcClient.Disconnect()
+
+	// Register the peer locally
+	if err := h.server.peers.Register(p); err != nil {
+		p.Log().Error("Light Ethereum peer registration failed", "err", err)
+		return err
+	}
+	clientConnectionGauge.Update(int64(h.server.peers.Len()))
+
+	// add dummy balance tracker for tests
+	if p.balanceTracker == nil {
+		p.balanceTracker = &balanceTracker{}
+		p.balanceTracker.init(&mclock.System{}, 1)
+	}
+
+	connectedAt := mclock.Now()
+	defer func() {
+		p.balanceTracker = nil
+		h.server.peers.Unregister(p.id)
+		clientConnectionGauge.Update(int64(h.server.peers.Len()))
+		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
+	}()
+
+	// Spawn a main loop to handle all incoming messages.
+	for {
+		select {
+		case err := <-p.errCh:
+			p.Log().Debug("Failed to send light ethereum response", "err", err)
+			return err
+		default:
+		}
+		if err := h.handleMsg(p); err != nil {
+			p.Log().Debug("Light Ethereum message handling failed", "err", err)
+			return err
+		}
+	}
+}
+
+// handleMsg is invoked whenever an inbound message is received from a remote
+// peer. The remote connection is torn down upon returning any error.
+func (h *serverHandler) handleMsg(p *peer) error {
+	// Read the next message from the remote peer, and ensure it's fully consumed
+	msg, err := p.rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	p.Log().Trace("Light Ethereum message arrived", "code", msg.Code, "bytes", msg.Size)
+
+	// Discard large message which exceeds the limitation.
+	if msg.Size > ProtocolMaxMsgSize {
+		clientErrorMeter.Mark(1)
+		return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
+	}
+	defer msg.Discard()
+
+	var (
+		maxCost uint64
+		task    *servingTask
+	)
+	p.responseCount++
+	responseCount := p.responseCount
+	// accept returns an indicator whether the request can be served.
+	// If so, deduct the max cost from the flow control buffer.
+	accept := func(reqID, reqCnt, maxCnt uint64) bool {
+		// Short circuit if the peer is already frozen or the request is invalid.
+		inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
+		if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt {
+			p.fcClient.OneTimeCost(inSizeCost)
+			return false
+		}
+		// Prepaid max cost units before request been serving.
+		maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt)
+		accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
+		if !accepted {
+			p.freezeClient()
+			p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
+			p.fcClient.OneTimeCost(inSizeCost)
+			return false
+		}
+		// Create a multi-stage task, estimate the time it takes for the task to
+		// execute, and cache it in the request service queue.
+		factor := h.server.costTracker.globalFactor()
+		if factor < 0.001 {
+			factor = 1
+			p.Log().Error("Invalid global cost factor", "factor", factor)
+		}
+		maxTime := uint64(float64(maxCost) / factor)
+		task = h.server.servingQueue.newTask(p, maxTime, priority)
+		if task.start() {
+			return true
+		}
+		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
+		return false
+	}
+	// sendResponse sends back the response and updates the flow control statistic.
+	sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) {
+		p.responseLock.Lock()
+		defer p.responseLock.Unlock()
+
+		// Short circuit if the client is already frozen.
+		if p.isFrozen() {
+			realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0)
+			p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
+			return
+		}
+		// Positive correction buffer value with real cost.
+		var replySize uint32
+		if reply != nil {
+			replySize = reply.size()
+		}
+		var realCost uint64
+		if h.server.costTracker.testing {
+			realCost = maxCost // Assign a fake cost for testing purpose
+		} else {
+			realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize)
+		}
+		bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
+		if amount != 0 {
+			// Feed cost tracker request serving statistic.
+			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
+			// Reduce priority "balance" for the specific peer.
+			p.balanceTracker.requestCost(realCost)
+		}
+		if reply != nil {
+			p.queueSend(func() {
+				if err := reply.send(bv); err != nil {
+					select {
+					case p.errCh <- err:
+					default:
+					}
+				}
+			})
+		}
+	}
+	switch msg.Code {
+	case GetBlockHeadersMsg:
+		p.Log().Trace("Received block header request")
+		if metrics.EnabledExpensive {
+			miscInHeaderPacketsMeter.Mark(1)
+			miscInHeaderTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID uint64
+			Query getBlockHeadersData
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "%v: %v", msg, err)
+		}
+		query := req.Query
+		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
+			go func() {
+				hashMode := query.Origin.Hash != (common.Hash{})
+				first := true
+				maxNonCanonical := uint64(100)
+
+				// Gather headers until the fetch or network limits is reached
+				var (
+					bytes   common.StorageSize
+					headers []*types.Header
+					unknown bool
+				)
+				for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit {
+					if !first && !task.waitOrStop() {
+						sendResponse(req.ReqID, 0, nil, task.servingTime)
+						return
+					}
+					// Retrieve the next header satisfying the query
+					var origin *types.Header
+					if hashMode {
+						if first {
+							origin = h.blockchain.GetHeaderByHash(query.Origin.Hash)
+							if origin != nil {
+								query.Origin.Number = origin.Number.Uint64()
+							}
+						} else {
+							origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
+						}
+					} else {
+						origin = h.blockchain.GetHeaderByNumber(query.Origin.Number)
+					}
+					if origin == nil {
+						atomic.AddUint32(&p.invalidCount, 1)
+						break
+					}
+					headers = append(headers, origin)
+					bytes += estHeaderRlpSize
+
+					// Advance to the next header of the query
+					switch {
+					case hashMode && query.Reverse:
+						// Hash based traversal towards the genesis block
+						ancestor := query.Skip + 1
+						if ancestor == 0 {
+							unknown = true
+						} else {
+							query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
+							unknown = query.Origin.Hash == common.Hash{}
+						}
+					case hashMode && !query.Reverse:
+						// Hash based traversal towards the leaf block
+						var (
+							current = origin.Number.Uint64()
+							next    = current + query.Skip + 1
+						)
+						if next <= current {
+							infos, _ := json.MarshalIndent(p.Peer.Info(), "", "  ")
+							p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos)
+							unknown = true
+						} else {
+							if header := h.blockchain.GetHeaderByNumber(next); header != nil {
+								nextHash := header.Hash()
+								expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
+								if expOldHash == query.Origin.Hash {
+									query.Origin.Hash, query.Origin.Number = nextHash, next
+								} else {
+									unknown = true
+								}
+							} else {
+								unknown = true
+							}
+						}
+					case query.Reverse:
+						// Number based traversal towards the genesis block
+						if query.Origin.Number >= query.Skip+1 {
+							query.Origin.Number -= query.Skip + 1
+						} else {
+							unknown = true
+						}
+
+					case !query.Reverse:
+						// Number based traversal towards the leaf block
+						query.Origin.Number += query.Skip + 1
+					}
+					first = false
+				}
+				reply := p.ReplyBlockHeaders(req.ReqID, headers)
+				sendResponse(req.ReqID, query.Amount, p.ReplyBlockHeaders(req.ReqID, headers), task.done())
+				if metrics.EnabledExpensive {
+					miscOutHeaderPacketsMeter.Mark(1)
+					miscOutHeaderTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	case GetBlockBodiesMsg:
+		p.Log().Trace("Received block bodies request")
+		if metrics.EnabledExpensive {
+			miscInBodyPacketsMeter.Mark(1)
+			miscInBodyTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID  uint64
+			Hashes []common.Hash
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		var (
+			bytes  int
+			bodies []rlp.RawValue
+		)
+		reqCnt := len(req.Hashes)
+		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
+			go func() {
+				for i, hash := range req.Hashes {
+					if i != 0 && !task.waitOrStop() {
+						sendResponse(req.ReqID, 0, nil, task.servingTime)
+						return
+					}
+					if bytes >= softResponseLimit {
+						break
+					}
+					body := h.blockchain.GetBodyRLP(hash)
+					if body == nil {
+						atomic.AddUint32(&p.invalidCount, 1)
+						continue
+					}
+					bodies = append(bodies, body)
+					bytes += len(body)
+				}
+				reply := p.ReplyBlockBodiesRLP(req.ReqID, bodies)
+				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
+				if metrics.EnabledExpensive {
+					miscOutBodyPacketsMeter.Mark(1)
+					miscOutBodyTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	case GetCodeMsg:
+		p.Log().Trace("Received code request")
+		if metrics.EnabledExpensive {
+			miscInCodePacketsMeter.Mark(1)
+			miscInCodeTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID uint64
+			Reqs  []CodeReq
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		var (
+			bytes int
+			data  [][]byte
+		)
+		reqCnt := len(req.Reqs)
+		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
+			go func() {
+				for i, request := range req.Reqs {
+					if i != 0 && !task.waitOrStop() {
+						sendResponse(req.ReqID, 0, nil, task.servingTime)
+						return
+					}
+					// Look up the root hash belonging to the request
+					header := h.blockchain.GetHeaderByHash(request.BHash)
+					if header == nil {
+						p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash)
+						atomic.AddUint32(&p.invalidCount, 1)
+						continue
+					}
+					// Refuse to search stale state data in the database since looking for
+					// a non-exist key is kind of expensive.
+					local := h.blockchain.CurrentHeader().Number.Uint64()
+					if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
+						p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local)
+						atomic.AddUint32(&p.invalidCount, 1)
+						continue
+					}
+					triedb := h.blockchain.StateCache().TrieDB()
+
+					account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
+					if err != nil {
+						p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
+						atomic.AddUint32(&p.invalidCount, 1)
+						continue
+					}
+					code, err := triedb.Node(common.BytesToHash(account.CodeHash))
+					if err != nil {
+						p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
+						continue
+					}
+					// Accumulate the code and abort if enough data was retrieved
+					data = append(data, code)
+					if bytes += len(code); bytes >= softResponseLimit {
+						break
+					}
+				}
+				reply := p.ReplyCode(req.ReqID, data)
+				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
+				if metrics.EnabledExpensive {
+					miscOutCodePacketsMeter.Mark(1)
+					miscOutCodeTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	case GetReceiptsMsg:
+		p.Log().Trace("Received receipts request")
+		if metrics.EnabledExpensive {
+			miscInReceiptPacketsMeter.Mark(1)
+			miscInReceiptTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID  uint64
+			Hashes []common.Hash
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		var (
+			bytes    int
+			receipts []rlp.RawValue
+		)
+		reqCnt := len(req.Hashes)
+		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
+			go func() {
+				for i, hash := range req.Hashes {
+					if i != 0 && !task.waitOrStop() {
+						sendResponse(req.ReqID, 0, nil, task.servingTime)
+						return
+					}
+					if bytes >= softResponseLimit {
+						break
+					}
+					// Retrieve the requested block's receipts, skipping if unknown to us
+					results := h.blockchain.GetReceiptsByHash(hash)
+					if results == nil {
+						if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
+							atomic.AddUint32(&p.invalidCount, 1)
+							continue
+						}
+					}
+					// If known, encode and queue for response packet
+					if encoded, err := rlp.EncodeToBytes(results); err != nil {
+						log.Error("Failed to encode receipt", "err", err)
+					} else {
+						receipts = append(receipts, encoded)
+						bytes += len(encoded)
+					}
+				}
+				reply := p.ReplyReceiptsRLP(req.ReqID, receipts)
+				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
+				if metrics.EnabledExpensive {
+					miscOutReceiptPacketsMeter.Mark(1)
+					miscOutReceiptTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	case GetProofsV2Msg:
+		p.Log().Trace("Received les/2 proofs request")
+		if metrics.EnabledExpensive {
+			miscInTrieProofPacketsMeter.Mark(1)
+			miscInTrieProofTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID uint64
+			Reqs  []ProofReq
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		// Gather state data until the fetch or network limits is reached
+		var (
+			lastBHash common.Hash
+			root      common.Hash
+		)
+		reqCnt := len(req.Reqs)
+		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
+			go func() {
+				nodes := light.NewNodeSet()
+
+				for i, request := range req.Reqs {
+					if i != 0 && !task.waitOrStop() {
+						sendResponse(req.ReqID, 0, nil, task.servingTime)
+						return
+					}
+					// Look up the root hash belonging to the request
+					var (
+						number *uint64
+						header *types.Header
+						trie   state.Trie
+					)
+					if request.BHash != lastBHash {
+						root, lastBHash = common.Hash{}, request.BHash
+
+						if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil {
+							p.Log().Warn("Failed to retrieve header for proof", "block", *number, "hash", request.BHash)
+							atomic.AddUint32(&p.invalidCount, 1)
+							continue
+						}
+						// Refuse to search stale state data in the database since looking for
+						// a non-exist key is kind of expensive.
+						local := h.blockchain.CurrentHeader().Number.Uint64()
+						if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local {
+							p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local)
+							atomic.AddUint32(&p.invalidCount, 1)
+							continue
+						}
+						root = header.Root
+					}
+					// If a header lookup failed (non existent), ignore subsequent requests for the same header
+					if root == (common.Hash{}) {
+						atomic.AddUint32(&p.invalidCount, 1)
+						continue
+					}
+					// Open the account or storage trie for the request
+					statedb := h.blockchain.StateCache()
+
+					switch len(request.AccKey) {
+					case 0:
+						// No account key specified, open an account trie
+						trie, err = statedb.OpenTrie(root)
+						if trie == nil || err != nil {
+							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err)
+							continue
+						}
+					default:
+						// Account key specified, open a storage trie
+						account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
+						if err != nil {
+							p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
+							atomic.AddUint32(&p.invalidCount, 1)
+							continue
+						}
+						trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root)
+						if trie == nil || err != nil {
+							p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
+							continue
+						}
+					}
+					// Prove the user's request from the account or stroage trie
+					if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil {
+						p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err)
+						continue
+					}
+					if nodes.DataSize() >= softResponseLimit {
+						break
+					}
+				}
+				reply := p.ReplyProofsV2(req.ReqID, nodes.NodeList())
+				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
+				if metrics.EnabledExpensive {
+					miscOutTrieProofPacketsMeter.Mark(1)
+					miscOutTrieProofTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	case GetHelperTrieProofsMsg:
+		p.Log().Trace("Received helper trie proof request")
+		if metrics.EnabledExpensive {
+			miscInHelperTriePacketsMeter.Mark(1)
+			miscInHelperTrieTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID uint64
+			Reqs  []HelperTrieReq
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		// Gather state data until the fetch or network limits is reached
+		var (
+			auxBytes int
+			auxData  [][]byte
+		)
+		reqCnt := len(req.Reqs)
+		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
+			go func() {
+				var (
+					lastIdx  uint64
+					lastType uint
+					root     common.Hash
+					auxTrie  *trie.Trie
+				)
+				nodes := light.NewNodeSet()
+				for i, request := range req.Reqs {
+					if i != 0 && !task.waitOrStop() {
+						sendResponse(req.ReqID, 0, nil, task.servingTime)
+						return
+					}
+					if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx {
+						auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx
+
+						var prefix string
+						if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) {
+							auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix)))
+						}
+					}
+					if request.AuxReq == auxRoot {
+						var data []byte
+						if root != (common.Hash{}) {
+							data = root[:]
+						}
+						auxData = append(auxData, data)
+						auxBytes += len(data)
+					} else {
+						if auxTrie != nil {
+							auxTrie.Prove(request.Key, request.FromLevel, nodes)
+						}
+						if request.AuxReq != 0 {
+							data := h.getAuxiliaryHeaders(request)
+							auxData = append(auxData, data)
+							auxBytes += len(data)
+						}
+					}
+					if nodes.DataSize()+auxBytes >= softResponseLimit {
+						break
+					}
+				}
+				reply := p.ReplyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData})
+				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
+				if metrics.EnabledExpensive {
+					miscOutHelperTriePacketsMeter.Mark(1)
+					miscOutHelperTrieTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	case SendTxV2Msg:
+		p.Log().Trace("Received new transactions")
+		if metrics.EnabledExpensive {
+			miscInTxsPacketsMeter.Mark(1)
+			miscInTxsTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID uint64
+			Txs   []*types.Transaction
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		reqCnt := len(req.Txs)
+		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
+			go func() {
+				stats := make([]light.TxStatus, len(req.Txs))
+				for i, tx := range req.Txs {
+					if i != 0 && !task.waitOrStop() {
+						return
+					}
+					hash := tx.Hash()
+					stats[i] = h.txStatus(hash)
+					if stats[i].Status == core.TxStatusUnknown {
+						addFn := h.txpool.AddRemotes
+						// Add txs synchronously for testing purpose
+						if h.addTxsSync {
+							addFn = h.txpool.AddRemotesSync
+						}
+						if errs := addFn([]*types.Transaction{tx}); errs[0] != nil {
+							stats[i].Error = errs[0].Error()
+							continue
+						}
+						stats[i] = h.txStatus(hash)
+					}
+				}
+				reply := p.ReplyTxStatus(req.ReqID, stats)
+				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
+				if metrics.EnabledExpensive {
+					miscOutTxsPacketsMeter.Mark(1)
+					miscOutTxsTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	case GetTxStatusMsg:
+		p.Log().Trace("Received transaction status query request")
+		if metrics.EnabledExpensive {
+			miscInTxStatusPacketsMeter.Mark(1)
+			miscInTxStatusTrafficMeter.Mark(int64(msg.Size))
+		}
+		var req struct {
+			ReqID  uint64
+			Hashes []common.Hash
+		}
+		if err := msg.Decode(&req); err != nil {
+			clientErrorMeter.Mark(1)
+			return errResp(ErrDecode, "msg %v: %v", msg, err)
+		}
+		reqCnt := len(req.Hashes)
+		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
+			go func() {
+				stats := make([]light.TxStatus, len(req.Hashes))
+				for i, hash := range req.Hashes {
+					if i != 0 && !task.waitOrStop() {
+						sendResponse(req.ReqID, 0, nil, task.servingTime)
+						return
+					}
+					stats[i] = h.txStatus(hash)
+				}
+				reply := p.ReplyTxStatus(req.ReqID, stats)
+				sendResponse(req.ReqID, uint64(reqCnt), reply, task.done())
+				if metrics.EnabledExpensive {
+					miscOutTxStatusPacketsMeter.Mark(1)
+					miscOutTxStatusTrafficMeter.Mark(int64(reply.size()))
+				}
+			}()
+		}
+
+	default:
+		p.Log().Trace("Received invalid message", "code", msg.Code)
+		clientErrorMeter.Mark(1)
+		return errResp(ErrInvalidMsgCode, "%v", msg.Code)
+	}
+	// If the client has made too much invalid request(e.g. request a non-exist data),
+	// reject them to prevent SPAM attack.
+	if atomic.LoadUint32(&p.invalidCount) > maxRequestErrors {
+		clientErrorMeter.Mark(1)
+		return errTooManyInvalidRequest
+	}
+	return nil
+}
+
+// getAccount retrieves an account from the state based on root.
+func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) {
+	trie, err := trie.New(root, triedb)
+	if err != nil {
+		return state.Account{}, err
+	}
+	blob, err := trie.TryGet(hash[:])
+	if err != nil {
+		return state.Account{}, err
+	}
+	var account state.Account
+	if err = rlp.DecodeBytes(blob, &account); err != nil {
+		return state.Account{}, err
+	}
+	return account, nil
+}
+
+// getHelperTrie returns the post-processed trie root for the given trie ID and section index
+func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) {
+	switch typ {
+	case htCanonical:
+		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1)
+		return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix
+	case htBloomBits:
+		sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1)
+		return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix
+	}
+	return common.Hash{}, ""
+}
+
+// getAuxiliaryHeaders returns requested auxiliary headers for the CHT request.
+func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte {
+	if req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8 {
+		blockNum := binary.BigEndian.Uint64(req.Key)
+		hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum)
+		return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum)
+	}
+	return nil
+}
+
+// txStatus returns the status of a specified transaction.
+func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
+	var stat light.TxStatus
+	// Looking the transaction in txpool first.
+	stat.Status = h.txpool.Status([]common.Hash{hash})[0]
+
+	// If the transaction is unknown to the pool, try looking it up locally.
+	if stat.Status == core.TxStatusUnknown {
+		lookup := h.blockchain.GetTransactionLookup(hash)
+		if lookup != nil {
+			stat.Status = core.TxStatusIncluded
+			stat.Lookup = lookup
+		}
+	}
+	return stat
+}
+
+// broadcastHeaders broadcasts new block information to all connected light
+// clients. According to the agreement between client and server, server should
+// only broadcast new announcement if the total difficulty is higher than the
+// last one. Besides server will add the signature if client requires.
+func (h *serverHandler) broadcastHeaders() {
+	defer h.wg.Done()
+
+	headCh := make(chan core.ChainHeadEvent, 10)
+	headSub := h.blockchain.SubscribeChainHeadEvent(headCh)
+	defer headSub.Unsubscribe()
+
+	var (
+		lastHead *types.Header
+		lastTd   = common.Big0
+	)
+	for {
+		select {
+		case ev := <-headCh:
+			peers := h.server.peers.AllPeers()
+			if len(peers) == 0 {
+				continue
+			}
+			header := ev.Block.Header()
+			hash, number := header.Hash(), header.Number.Uint64()
+			td := h.blockchain.GetTd(hash, number)
+			if td == nil || td.Cmp(lastTd) <= 0 {
+				continue
+			}
+			var reorg uint64
+			if lastHead != nil {
+				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
+			}
+			lastHead, lastTd = header, td
+
+			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
+			var (
+				signed         bool
+				signedAnnounce announceData
+			)
+			announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
+			for _, p := range peers {
+				p := p
+				switch p.announceType {
+				case announceTypeSimple:
+					p.queueSend(func() { p.SendAnnounce(announce) })
+				case announceTypeSigned:
+					if !signed {
+						signedAnnounce = announce
+						signedAnnounce.sign(h.server.privateKey)
+						signed = true
+					}
+					p.queueSend(func() { p.SendAnnounce(signedAnnounce) })
+				}
+			}
+		case <-h.closeCh:
+			return
+		}
+	}
+}

+ 36 - 25
les/serverpool.go

@@ -115,8 +115,6 @@ type serverPool struct {
 	db     ethdb.Database
 	db     ethdb.Database
 	dbKey  []byte
 	dbKey  []byte
 	server *p2p.Server
 	server *p2p.Server
-	quit   chan struct{}
-	wg     *sync.WaitGroup
 	connWg sync.WaitGroup
 	connWg sync.WaitGroup
 
 
 	topic discv5.Topic
 	topic discv5.Topic
@@ -137,14 +135,15 @@ type serverPool struct {
 	connCh                     chan *connReq
 	connCh                     chan *connReq
 	disconnCh                  chan *disconnReq
 	disconnCh                  chan *disconnReq
 	registerCh                 chan *registerReq
 	registerCh                 chan *registerReq
+
+	closeCh chan struct{}
+	wg      sync.WaitGroup
 }
 }
 
 
 // newServerPool creates a new serverPool instance
 // newServerPool creates a new serverPool instance
-func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, trustedNodes []string) *serverPool {
+func newServerPool(db ethdb.Database, ulcServers []string) *serverPool {
 	pool := &serverPool{
 	pool := &serverPool{
 		db:           db,
 		db:           db,
-		quit:         quit,
-		wg:           wg,
 		entries:      make(map[enode.ID]*poolEntry),
 		entries:      make(map[enode.ID]*poolEntry),
 		timeout:      make(chan *poolEntry, 1),
 		timeout:      make(chan *poolEntry, 1),
 		adjustStats:  make(chan poolStatAdjust, 100),
 		adjustStats:  make(chan poolStatAdjust, 100),
@@ -152,10 +151,11 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, tr
 		connCh:       make(chan *connReq),
 		connCh:       make(chan *connReq),
 		disconnCh:    make(chan *disconnReq),
 		disconnCh:    make(chan *disconnReq),
 		registerCh:   make(chan *registerReq),
 		registerCh:   make(chan *registerReq),
+		closeCh:      make(chan struct{}),
 		knownSelect:  newWeightedRandomSelect(),
 		knownSelect:  newWeightedRandomSelect(),
 		newSelect:    newWeightedRandomSelect(),
 		newSelect:    newWeightedRandomSelect(),
 		fastDiscover: true,
 		fastDiscover: true,
-		trustedNodes: parseTrustedNodes(trustedNodes),
+		trustedNodes: parseTrustedNodes(ulcServers),
 	}
 	}
 
 
 	pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry)
 	pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry)
@@ -167,7 +167,6 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
 	pool.server = server
 	pool.server = server
 	pool.topic = topic
 	pool.topic = topic
 	pool.dbKey = append([]byte("serverPool/"), []byte(topic)...)
 	pool.dbKey = append([]byte("serverPool/"), []byte(topic)...)
-	pool.wg.Add(1)
 	pool.loadNodes()
 	pool.loadNodes()
 	pool.connectToTrustedNodes()
 	pool.connectToTrustedNodes()
 
 
@@ -178,9 +177,15 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
 		go pool.discoverNodes()
 		go pool.discoverNodes()
 	}
 	}
 	pool.checkDial()
 	pool.checkDial()
+	pool.wg.Add(1)
 	go pool.eventLoop()
 	go pool.eventLoop()
 }
 }
 
 
+func (pool *serverPool) stop() {
+	close(pool.closeCh)
+	pool.wg.Wait()
+}
+
 // discoverNodes wraps SearchTopic, converting result nodes to enode.Node.
 // discoverNodes wraps SearchTopic, converting result nodes to enode.Node.
 func (pool *serverPool) discoverNodes() {
 func (pool *serverPool) discoverNodes() {
 	ch := make(chan *discv5.Node)
 	ch := make(chan *discv5.Node)
@@ -207,7 +212,7 @@ func (pool *serverPool) connect(p *peer, node *enode.Node) *poolEntry {
 	req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)}
 	req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)}
 	select {
 	select {
 	case pool.connCh <- req:
 	case pool.connCh <- req:
-	case <-pool.quit:
+	case <-pool.closeCh:
 		return nil
 		return nil
 	}
 	}
 	return <-req.result
 	return <-req.result
@@ -219,7 +224,7 @@ func (pool *serverPool) registered(entry *poolEntry) {
 	req := &registerReq{entry: entry, done: make(chan struct{})}
 	req := &registerReq{entry: entry, done: make(chan struct{})}
 	select {
 	select {
 	case pool.registerCh <- req:
 	case pool.registerCh <- req:
-	case <-pool.quit:
+	case <-pool.closeCh:
 		return
 		return
 	}
 	}
 	<-req.done
 	<-req.done
@@ -231,7 +236,7 @@ func (pool *serverPool) registered(entry *poolEntry) {
 func (pool *serverPool) disconnect(entry *poolEntry) {
 func (pool *serverPool) disconnect(entry *poolEntry) {
 	stopped := false
 	stopped := false
 	select {
 	select {
-	case <-pool.quit:
+	case <-pool.closeCh:
 		stopped = true
 		stopped = true
 	default:
 	default:
 	}
 	}
@@ -278,6 +283,7 @@ func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration,
 
 
 // eventLoop handles pool events and mutex locking for all internal functions
 // eventLoop handles pool events and mutex locking for all internal functions
 func (pool *serverPool) eventLoop() {
 func (pool *serverPool) eventLoop() {
+	defer pool.wg.Done()
 	lookupCnt := 0
 	lookupCnt := 0
 	var convTime mclock.AbsTime
 	var convTime mclock.AbsTime
 	if pool.discSetPeriod != nil {
 	if pool.discSetPeriod != nil {
@@ -361,7 +367,7 @@ func (pool *serverPool) eventLoop() {
 		case req := <-pool.connCh:
 		case req := <-pool.connCh:
 			if pool.trustedNodes[req.p.ID()] != nil {
 			if pool.trustedNodes[req.p.ID()] != nil {
 				// ignore trusted nodes
 				// ignore trusted nodes
-				req.result <- nil
+				req.result <- &poolEntry{trusted: true}
 			} else {
 			} else {
 				// Handle peer connection requests.
 				// Handle peer connection requests.
 				entry := pool.entries[req.p.ID()]
 				entry := pool.entries[req.p.ID()]
@@ -389,6 +395,9 @@ func (pool *serverPool) eventLoop() {
 			}
 			}
 
 
 		case req := <-pool.registerCh:
 		case req := <-pool.registerCh:
+			if req.entry.trusted {
+				continue
+			}
 			// Handle peer registration requests.
 			// Handle peer registration requests.
 			entry := req.entry
 			entry := req.entry
 			entry.state = psRegistered
 			entry.state = psRegistered
@@ -402,10 +411,13 @@ func (pool *serverPool) eventLoop() {
 			close(req.done)
 			close(req.done)
 
 
 		case req := <-pool.disconnCh:
 		case req := <-pool.disconnCh:
+			if req.entry.trusted {
+				continue
+			}
 			// Handle peer disconnection requests.
 			// Handle peer disconnection requests.
 			disconnect(req, req.stopped)
 			disconnect(req, req.stopped)
 
 
-		case <-pool.quit:
+		case <-pool.closeCh:
 			if pool.discSetPeriod != nil {
 			if pool.discSetPeriod != nil {
 				close(pool.discSetPeriod)
 				close(pool.discSetPeriod)
 			}
 			}
@@ -421,7 +433,6 @@ func (pool *serverPool) eventLoop() {
 				disconnect(req, true)
 				disconnect(req, true)
 			}
 			}
 			pool.saveNodes()
 			pool.saveNodes()
-			pool.wg.Done()
 			return
 			return
 		}
 		}
 	}
 	}
@@ -549,10 +560,10 @@ func (pool *serverPool) setRetryDial(entry *poolEntry) {
 	entry.delayedRetry = true
 	entry.delayedRetry = true
 	go func() {
 	go func() {
 		select {
 		select {
-		case <-pool.quit:
+		case <-pool.closeCh:
 		case <-time.After(delay):
 		case <-time.After(delay):
 			select {
 			select {
-			case <-pool.quit:
+			case <-pool.closeCh:
 			case pool.enableRetry <- entry:
 			case pool.enableRetry <- entry:
 			}
 			}
 		}
 		}
@@ -618,10 +629,10 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) {
 	go func() {
 	go func() {
 		pool.server.AddPeer(entry.node)
 		pool.server.AddPeer(entry.node)
 		select {
 		select {
-		case <-pool.quit:
+		case <-pool.closeCh:
 		case <-time.After(dialTimeout):
 		case <-time.After(dialTimeout):
 			select {
 			select {
-			case <-pool.quit:
+			case <-pool.closeCh:
 			case pool.timeout <- entry:
 			case pool.timeout <- entry:
 			}
 			}
 		}
 		}
@@ -662,14 +673,14 @@ type poolEntry struct {
 	lastConnected, dialed *poolEntryAddress
 	lastConnected, dialed *poolEntryAddress
 	addrSelect            weightedRandomSelect
 	addrSelect            weightedRandomSelect
 
 
-	lastDiscovered              mclock.AbsTime
-	known, knownSelected        bool
-	connectStats, delayStats    poolStats
-	responseStats, timeoutStats poolStats
-	state                       int
-	regTime                     mclock.AbsTime
-	queueIdx                    int
-	removed                     bool
+	lastDiscovered                mclock.AbsTime
+	known, knownSelected, trusted bool
+	connectStats, delayStats      poolStats
+	responseStats, timeoutStats   poolStats
+	state                         int
+	regTime                       mclock.AbsTime
+	queueIdx                      int
+	removed                       bool
 
 
 	delayedRetry bool
 	delayedRetry bool
 	shortRetry   int
 	shortRetry   int

+ 21 - 50
les/sync.go

@@ -43,35 +43,6 @@ const (
 	checkpointSync
 	checkpointSync
 )
 )
 
 
-// syncer is responsible for periodically synchronising with the network, both
-// downloading hashes and blocks as well as handling the announcement handler.
-func (pm *ProtocolManager) syncer() {
-	// Start and ensure cleanup of sync mechanisms
-	//pm.fetcher.Start()
-	//defer pm.fetcher.Stop()
-	defer pm.downloader.Terminate()
-
-	// Wait for different events to fire synchronisation operations
-	//forceSync := time.Tick(forceSyncCycle)
-	for {
-		select {
-		case <-pm.newPeerCh:
-			/*			// Make sure we have peers to select from, then sync
-						if pm.peers.Len() < minDesiredPeerCount {
-							break
-						}
-						go pm.synchronise(pm.peers.BestPeer())
-			*/
-		/*case <-forceSync:
-		// Force a sync even if not enough peers are present
-		go pm.synchronise(pm.peers.BestPeer())
-		*/
-		case <-pm.noMorePeers:
-			return
-		}
-	}
-}
-
 // validateCheckpoint verifies the advertised checkpoint by peer is valid or not.
 // validateCheckpoint verifies the advertised checkpoint by peer is valid or not.
 //
 //
 // Each network has several hard-coded checkpoint signer addresses. Only the
 // Each network has several hard-coded checkpoint signer addresses. Only the
@@ -80,22 +51,22 @@ func (pm *ProtocolManager) syncer() {
 // In addition to the checkpoint registered in the registrar contract, there are
 // In addition to the checkpoint registered in the registrar contract, there are
 // several legacy hardcoded checkpoints in our codebase. These checkpoints are
 // several legacy hardcoded checkpoints in our codebase. These checkpoints are
 // also considered as valid.
 // also considered as valid.
-func (pm *ProtocolManager) validateCheckpoint(peer *peer) error {
+func (h *clientHandler) validateCheckpoint(peer *peer) error {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 	defer cancel()
 	defer cancel()
 
 
 	// Fetch the block header corresponding to the checkpoint registration.
 	// Fetch the block header corresponding to the checkpoint registration.
 	cp := peer.checkpoint
 	cp := peer.checkpoint
-	header, err := light.GetUntrustedHeaderByNumber(ctx, pm.odr, peer.checkpointNumber, peer.id)
+	header, err := light.GetUntrustedHeaderByNumber(ctx, h.backend.odr, peer.checkpointNumber, peer.id)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	// Fetch block logs associated with the block header.
 	// Fetch block logs associated with the block header.
-	logs, err := light.GetUntrustedBlockLogs(ctx, pm.odr, header)
+	logs, err := light.GetUntrustedBlockLogs(ctx, h.backend.odr, header)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	events := pm.reg.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash())
+	events := h.backend.oracle.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash())
 	if len(events) == 0 {
 	if len(events) == 0 {
 		return errInvalidCheckpoint
 		return errInvalidCheckpoint
 	}
 	}
@@ -107,7 +78,7 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error {
 	for _, event := range events {
 	for _, event := range events {
 		signatures = append(signatures, append(event.R[:], append(event.S[:], event.V)...))
 		signatures = append(signatures, append(event.R[:], append(event.S[:], event.V)...))
 	}
 	}
-	valid, signers := pm.reg.verifySigners(index, hash, signatures)
+	valid, signers := h.backend.oracle.verifySigners(index, hash, signatures)
 	if !valid {
 	if !valid {
 		return errInvalidCheckpoint
 		return errInvalidCheckpoint
 	}
 	}
@@ -116,14 +87,14 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error {
 }
 }
 
 
 // synchronise tries to sync up our local chain with a remote peer.
 // synchronise tries to sync up our local chain with a remote peer.
-func (pm *ProtocolManager) synchronise(peer *peer) {
+func (h *clientHandler) synchronise(peer *peer) {
 	// Short circuit if the peer is nil.
 	// Short circuit if the peer is nil.
 	if peer == nil {
 	if peer == nil {
 		return
 		return
 	}
 	}
 	// Make sure the peer's TD is higher than our own.
 	// Make sure the peer's TD is higher than our own.
-	latest := pm.blockchain.CurrentHeader()
-	currentTd := rawdb.ReadTd(pm.chainDb, latest.Hash(), latest.Number.Uint64())
+	latest := h.backend.blockchain.CurrentHeader()
+	currentTd := rawdb.ReadTd(h.backend.chainDb, latest.Hash(), latest.Number.Uint64())
 	if currentTd != nil && peer.headBlockInfo().Td.Cmp(currentTd) < 0 {
 	if currentTd != nil && peer.headBlockInfo().Td.Cmp(currentTd) < 0 {
 		return
 		return
 	}
 	}
@@ -140,8 +111,8 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
 	//     => Use provided checkpoint
 	//     => Use provided checkpoint
 	var checkpoint = &peer.checkpoint
 	var checkpoint = &peer.checkpoint
 	var hardcoded bool
 	var hardcoded bool
-	if pm.checkpoint != nil && pm.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex {
-		checkpoint = pm.checkpoint // Use the hardcoded one.
+	if h.checkpoint != nil && h.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex {
+		checkpoint = h.checkpoint // Use the hardcoded one.
 		hardcoded = true
 		hardcoded = true
 	}
 	}
 	// Determine whether we should run checkpoint syncing or normal light syncing.
 	// Determine whether we should run checkpoint syncing or normal light syncing.
@@ -157,34 +128,34 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
 	case checkpoint.Empty():
 	case checkpoint.Empty():
 		mode = lightSync
 		mode = lightSync
 		log.Debug("Disable checkpoint syncing", "reason", "empty checkpoint")
 		log.Debug("Disable checkpoint syncing", "reason", "empty checkpoint")
-	case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*pm.iConfig.ChtSize-1:
+	case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*h.backend.iConfig.ChtSize-1:
 		mode = lightSync
 		mode = lightSync
 		log.Debug("Disable checkpoint syncing", "reason", "local chain beyond the checkpoint")
 		log.Debug("Disable checkpoint syncing", "reason", "local chain beyond the checkpoint")
 	case hardcoded:
 	case hardcoded:
 		mode = legacyCheckpointSync
 		mode = legacyCheckpointSync
 		log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded")
 		log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded")
-	case pm.reg == nil || !pm.reg.isRunning():
+	case h.backend.oracle == nil || !h.backend.oracle.isRunning():
 		mode = legacyCheckpointSync
 		mode = legacyCheckpointSync
 		log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated")
 		log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated")
 	}
 	}
 	// Notify testing framework if syncing has completed(for testing purpose).
 	// Notify testing framework if syncing has completed(for testing purpose).
 	defer func() {
 	defer func() {
-		if pm.reg != nil && pm.reg.syncDoneHook != nil {
-			pm.reg.syncDoneHook()
+		if h.backend.oracle != nil && h.backend.oracle.syncDoneHook != nil {
+			h.backend.oracle.syncDoneHook()
 		}
 		}
 	}()
 	}()
 	start := time.Now()
 	start := time.Now()
 	if mode == checkpointSync || mode == legacyCheckpointSync {
 	if mode == checkpointSync || mode == legacyCheckpointSync {
 		// Validate the advertised checkpoint
 		// Validate the advertised checkpoint
 		if mode == legacyCheckpointSync {
 		if mode == legacyCheckpointSync {
-			checkpoint = pm.checkpoint
+			checkpoint = h.checkpoint
 		} else if mode == checkpointSync {
 		} else if mode == checkpointSync {
-			if err := pm.validateCheckpoint(peer); err != nil {
+			if err := h.validateCheckpoint(peer); err != nil {
 				log.Debug("Failed to validate checkpoint", "reason", err)
 				log.Debug("Failed to validate checkpoint", "reason", err)
-				pm.removePeer(peer.id)
+				h.removePeer(peer.id)
 				return
 				return
 			}
 			}
-			pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(checkpoint)
+			h.backend.blockchain.AddTrustedCheckpoint(checkpoint)
 		}
 		}
 		log.Debug("Checkpoint syncing start", "peer", peer.id, "checkpoint", checkpoint.SectionIndex)
 		log.Debug("Checkpoint syncing start", "peer", peer.id, "checkpoint", checkpoint.SectionIndex)
 
 
@@ -197,14 +168,14 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
 		// of the latest epoch covered by checkpoint.
 		// of the latest epoch covered by checkpoint.
 		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
 		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
 		defer cancel()
 		defer cancel()
-		if !checkpoint.Empty() && !pm.blockchain.(*light.LightChain).SyncCheckpoint(ctx, checkpoint) {
+		if !checkpoint.Empty() && !h.backend.blockchain.SyncCheckpoint(ctx, checkpoint) {
 			log.Debug("Sync checkpoint failed")
 			log.Debug("Sync checkpoint failed")
-			pm.removePeer(peer.id)
+			h.removePeer(peer.id)
 			return
 			return
 		}
 		}
 	}
 	}
 	// Fetch the remaining block headers based on the current chain header.
 	// Fetch the remaining block headers based on the current chain header.
-	if err := pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil {
+	if err := h.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil {
 		log.Debug("Synchronise failed", "reason", err)
 		log.Debug("Synchronise failed", "reason", err)
 		return
 		return
 	}
 	}

+ 8 - 9
les/sync_test.go

@@ -57,7 +57,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
 		}
 		}
 	}
 	}
 	// Generate 512+4 blocks (totally 1 CHT sections)
 	// Generate 512+4 blocks (totally 1 CHT sections)
-	server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false)
+	server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, nil, 0, false, false)
 	defer tearDown()
 	defer tearDown()
 
 
 	expected := config.ChtSize + config.ChtConfirms
 	expected := config.ChtSize + config.ChtConfirms
@@ -74,8 +74,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
 		}
 		}
 		if syncMode == 1 {
 		if syncMode == 1 {
 			// Register the assembled checkpoint as hardcoded one.
 			// Register the assembled checkpoint as hardcoded one.
-			client.pm.checkpoint = cp
-			client.pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(cp)
+			client.handler.checkpoint = cp
+			client.handler.backend.blockchain.AddTrustedCheckpoint(cp)
 		} else {
 		} else {
 			// Register the assembled checkpoint into oracle.
 			// Register the assembled checkpoint into oracle.
 			header := server.backend.Blockchain().CurrentHeader()
 			header := server.backend.Blockchain().CurrentHeader()
@@ -83,14 +83,14 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
 			data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...)
 			data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...)
 			sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey)
 			sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey)
 			sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper
 			sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper
-			if _, err := server.pm.reg.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil {
+			if _, err := server.handler.server.oracle.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil {
 				t.Error("register checkpoint failed", err)
 				t.Error("register checkpoint failed", err)
 			}
 			}
 			server.backend.Commit()
 			server.backend.Commit()
 
 
 			// Wait for the checkpoint registration
 			// Wait for the checkpoint registration
 			for {
 			for {
-				_, hash, _, err := server.pm.reg.contract.Contract().GetLatestCheckpoint(nil)
+				_, hash, _, err := server.handler.server.oracle.contract.Contract().GetLatestCheckpoint(nil)
 				if err != nil || hash == [32]byte{} {
 				if err != nil || hash == [32]byte{} {
 					time.Sleep(100 * time.Millisecond)
 					time.Sleep(100 * time.Millisecond)
 					continue
 					continue
@@ -102,8 +102,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
 	}
 	}
 
 
 	done := make(chan error)
 	done := make(chan error)
-	client.pm.reg.syncDoneHook = func() {
-		header := client.pm.blockchain.CurrentHeader()
+	client.handler.backend.oracle.syncDoneHook = func() {
+		header := client.handler.backend.blockchain.CurrentHeader()
 		if header.Number.Uint64() == expected {
 		if header.Number.Uint64() == expected {
 			done <- nil
 			done <- nil
 		} else {
 		} else {
@@ -112,7 +112,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
 	}
 	}
 
 
 	// Create connected peer pair.
 	// Create connected peer pair.
-	peer, err1, lPeer, err2 := newTestPeerPair("peer", protocol, server.pm, client.pm)
+	_, err1, _, err2 := newTestPeerPair("peer", protocol, server.handler, client.handler)
 	select {
 	select {
 	case <-time.After(time.Millisecond * 100):
 	case <-time.After(time.Millisecond * 100):
 	case err := <-err1:
 	case err := <-err1:
@@ -120,7 +120,6 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
 	case err := <-err2:
 	case err := <-err2:
 		t.Fatalf("peer 2 handshake error: %v", err)
 		t.Fatalf("peer 2 handshake error: %v", err)
 	}
 	}
-	server.rPeer, client.rPeer = peer, lPeer
 
 
 	select {
 	select {
 	case err := <-done:
 	case err := <-done:

+ 247 - 201
les/helper_test.go → les/test_helper.go

@@ -23,7 +23,6 @@ import (
 	"context"
 	"context"
 	"crypto/rand"
 	"crypto/rand"
 	"math/big"
 	"math/big"
-	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -57,8 +56,8 @@ var (
 	userAddr1   = crypto.PubkeyToAddress(userKey1.PublicKey)
 	userAddr1   = crypto.PubkeyToAddress(userKey1.PublicKey)
 	userAddr2   = crypto.PubkeyToAddress(userKey2.PublicKey)
 	userAddr2   = crypto.PubkeyToAddress(userKey2.PublicKey)
 
 
-	testContractCode         = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056")
 	testContractAddr         common.Address
 	testContractAddr         common.Address
+	testContractCode         = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056")
 	testContractCodeDeployed = testContractCode[16:]
 	testContractCodeDeployed = testContractCode[16:]
 	testContractDeployed     = uint64(2)
 	testContractDeployed     = uint64(2)
 
 
@@ -77,8 +76,10 @@ var (
 	// The number of confirmations needed to generate a checkpoint(only used in test).
 	// The number of confirmations needed to generate a checkpoint(only used in test).
 	processConfirms = big.NewInt(4)
 	processConfirms = big.NewInt(4)
 
 
-	//
-	testBufLimit    = uint64(1000000)
+	// The token bucket buffer limit for testing purpose.
+	testBufLimit = uint64(1000000)
+
+	// The buffer recharging speed for testing purpose.
 	testBufRecharge = uint64(1000)
 	testBufRecharge = uint64(1000)
 )
 )
 
 
@@ -97,8 +98,8 @@ contract test {
 }
 }
 */
 */
 
 
-// prepareTestchain pre-commits specified number customized blocks into chain.
-func prepareTestchain(n int, backend *backends.SimulatedBackend) {
+// prepare pre-commits specified number customized blocks into chain.
+func prepare(n int, backend *backends.SimulatedBackend) {
 	var (
 	var (
 		ctx    = context.Background()
 		ctx    = context.Background()
 		signer = types.HomesteadSigner{}
 		signer = types.HomesteadSigner{}
@@ -164,51 +165,88 @@ func testIndexers(db ethdb.Database, odr light.OdrBackend, config *light.Indexer
 	return indexers[:]
 	return indexers[:]
 }
 }
 
 
-// newTestProtocolManager creates a new protocol manager for testing purposes,
-// with the given number of blocks already known, potential notification
-// channels for different events and relative chain indexers array.
-func newTestProtocolManager(lightSync bool, blocks int, odr *LesOdr, indexers []*core.ChainIndexer, peers *peerSet, db ethdb.Database, ulcServers []string, ulcFraction int, testCost uint64, clock mclock.Clock) (*ProtocolManager, *backends.SimulatedBackend, error) {
+func newTestClientHandler(backend *backends.SimulatedBackend, odr *LesOdr, indexers []*core.ChainIndexer, db ethdb.Database, peers *peerSet, ulcServers []string, ulcFraction int) *clientHandler {
 	var (
 	var (
 		evmux  = new(event.TypeMux)
 		evmux  = new(event.TypeMux)
 		engine = ethash.NewFaker()
 		engine = ethash.NewFaker()
 		gspec  = core.Genesis{
 		gspec  = core.Genesis{
-			Config: params.AllEthashProtocolChanges,
-			Alloc:  core.GenesisAlloc{bankAddr: {Balance: bankFunds}},
+			Config:   params.AllEthashProtocolChanges,
+			Alloc:    core.GenesisAlloc{bankAddr: {Balance: bankFunds}},
+			GasLimit: 100000000,
 		}
 		}
-		pool   txPool
-		chain  BlockChain
-		exitCh = make(chan struct{})
+		oracle *checkpointOracle
 	)
 	)
-	gspec.MustCommit(db)
-	if peers == nil {
-		peers = newPeerSet()
+	genesis := gspec.MustCommit(db)
+	chain, _ := light.NewLightChain(odr, gspec.Config, engine, nil)
+	if indexers != nil {
+		checkpointConfig := &params.CheckpointOracleConfig{
+			Address:   crypto.CreateAddress(bankAddr, 0),
+			Signers:   []common.Address{signerAddr},
+			Threshold: 1,
+		}
+		getLocal := func(index uint64) params.TrustedCheckpoint {
+			chtIndexer := indexers[0]
+			sectionHead := chtIndexer.SectionHead(index)
+			return params.TrustedCheckpoint{
+				SectionIndex: index,
+				SectionHead:  sectionHead,
+				CHTRoot:      light.GetChtRoot(db, index, sectionHead),
+				BloomRoot:    light.GetBloomTrieRoot(db, index, sectionHead),
+			}
+		}
+		oracle = newCheckpointOracle(checkpointConfig, getLocal)
 	}
 	}
-	// create a simulation backend and pre-commit several customized block to the database.
-	simulation := backends.NewSimulatedBackendWithDatabase(db, gspec.Alloc, 100000000)
-	prepareTestchain(blocks, simulation)
-
-	// initialize empty chain for light client or pre-committed chain for server.
-	if lightSync {
-		chain, _ = light.NewLightChain(odr, gspec.Config, engine, nil)
-	} else {
-		chain = simulation.Blockchain()
-		config := core.DefaultTxPoolConfig
-		config.Journal = ""
-		pool = core.NewTxPool(config, gspec.Config, simulation.Blockchain())
+	client := &LightEthereum{
+		lesCommons: lesCommons{
+			genesis:     genesis.Hash(),
+			config:      &eth.Config{LightPeers: 100, NetworkId: NetworkId},
+			chainConfig: params.AllEthashProtocolChanges,
+			iConfig:     light.TestClientIndexerConfig,
+			chainDb:     db,
+			oracle:      oracle,
+			chainReader: chain,
+			peers:       peers,
+			closeCh:     make(chan struct{}),
+		},
+		reqDist:    odr.retriever.dist,
+		retriever:  odr.retriever,
+		odr:        odr,
+		engine:     engine,
+		blockchain: chain,
+		eventMux:   evmux,
 	}
 	}
+	client.handler = newClientHandler(ulcServers, ulcFraction, nil, client)
 
 
-	// Create contract registrar
-	indexConfig := light.TestServerIndexerConfig
-	if lightSync {
-		indexConfig = light.TestClientIndexerConfig
+	if client.oracle != nil {
+		client.oracle.start(backend)
 	}
 	}
-	config := &params.CheckpointOracleConfig{
-		Address:   crypto.CreateAddress(bankAddr, 0),
-		Signers:   []common.Address{signerAddr},
-		Threshold: 1,
-	}
-	var reg *checkpointOracle
+	return client.handler
+}
+
+func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, peers *peerSet, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) {
+	var (
+		gspec = core.Genesis{
+			Config:   params.AllEthashProtocolChanges,
+			Alloc:    core.GenesisAlloc{bankAddr: {Balance: bankFunds}},
+			GasLimit: 100000000,
+		}
+		oracle *checkpointOracle
+	)
+	genesis := gspec.MustCommit(db)
+
+	// create a simulation backend and pre-commit several customized block to the database.
+	simulation := backends.NewSimulatedBackendWithDatabase(db, gspec.Alloc, 100000000)
+	prepare(blocks, simulation)
+
+	txpoolConfig := core.DefaultTxPoolConfig
+	txpoolConfig.Journal = ""
+	txpool := core.NewTxPool(txpoolConfig, gspec.Config, simulation.Blockchain())
 	if indexers != nil {
 	if indexers != nil {
+		checkpointConfig := &params.CheckpointOracleConfig{
+			Address:   crypto.CreateAddress(bankAddr, 0),
+			Signers:   []common.Address{signerAddr},
+			Threshold: 1,
+		}
 		getLocal := func(index uint64) params.TrustedCheckpoint {
 		getLocal := func(index uint64) params.TrustedCheckpoint {
 			chtIndexer := indexers[0]
 			chtIndexer := indexers[0]
 			sectionHead := chtIndexer.SectionHead(index)
 			sectionHead := chtIndexer.SectionHead(index)
@@ -219,72 +257,63 @@ func newTestProtocolManager(lightSync bool, blocks int, odr *LesOdr, indexers []
 				BloomRoot:    light.GetBloomTrieRoot(db, index, sectionHead),
 				BloomRoot:    light.GetBloomTrieRoot(db, index, sectionHead),
 			}
 			}
 		}
 		}
-		reg = newCheckpointOracle(config, getLocal)
-	}
-	pm, err := NewProtocolManager(gspec.Config, nil, indexConfig, ulcServers, ulcFraction, lightSync, NetworkId, evmux, peers, chain, pool, db, odr, nil, reg, exitCh, new(sync.WaitGroup), func() bool { return true })
-	if err != nil {
-		return nil, nil, err
+		oracle = newCheckpointOracle(checkpointConfig, getLocal)
 	}
 	}
-	// Registrar initialization could failed if checkpoint contract is not specified.
-	if pm.reg != nil {
-		pm.reg.start(simulation)
-	}
-	// Set up les server stuff.
-	if !lightSync {
-		srv := &LesServer{lesCommons: lesCommons{protocolManager: pm, chainDb: db}}
-		pm.server = srv
-		pm.servingQueue = newServingQueue(int64(time.Millisecond*10), 1)
-		pm.servingQueue.setThreads(4)
-
-		srv.defParams = flowcontrol.ServerParams{
+	server := &LesServer{
+		lesCommons: lesCommons{
+			genesis:     genesis.Hash(),
+			config:      &eth.Config{LightPeers: 100, NetworkId: NetworkId},
+			chainConfig: params.AllEthashProtocolChanges,
+			iConfig:     light.TestServerIndexerConfig,
+			chainDb:     db,
+			chainReader: simulation.Blockchain(),
+			oracle:      oracle,
+			peers:       peers,
+			closeCh:     make(chan struct{}),
+		},
+		servingQueue: newServingQueue(int64(time.Millisecond*10), 1),
+		defParams: flowcontrol.ServerParams{
 			BufLimit:    testBufLimit,
 			BufLimit:    testBufLimit,
 			MinRecharge: testBufRecharge,
 			MinRecharge: testBufRecharge,
-		}
-		srv.testCost = testCost
-		srv.fcManager = flowcontrol.NewClientManager(nil, clock)
+		},
+		fcManager: flowcontrol.NewClientManager(nil, clock),
 	}
 	}
-	pm.Start(1000)
-	return pm, simulation, nil
-}
-
-// newTestProtocolManagerMust creates a new protocol manager for testing purposes,
-// with the given number of blocks already known, potential notification channels
-// for different events and relative chain indexers array. In case of an error, the
-// constructor force-fails the test.
-func newTestProtocolManagerMust(t *testing.T, lightSync bool, blocks int, odr *LesOdr, indexers []*core.ChainIndexer, peers *peerSet, db ethdb.Database, ulcServers []string, ulcFraction int) (*ProtocolManager, *backends.SimulatedBackend) {
-	pm, backend, err := newTestProtocolManager(lightSync, blocks, odr, indexers, peers, db, ulcServers, ulcFraction, 0, &mclock.System{})
-	if err != nil {
-		t.Fatalf("Failed to create protocol manager: %v", err)
+	server.costTracker, server.freeCapacity = newCostTracker(db, server.config)
+	server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism.
+	server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true })
+	if server.oracle != nil {
+		server.oracle.start(simulation)
 	}
 	}
-	return pm, backend
+	server.servingQueue.setThreads(4)
+	server.handler.start()
+	return server.handler, simulation
 }
 }
 
 
 // testPeer is a simulated peer to allow testing direct network calls.
 // testPeer is a simulated peer to allow testing direct network calls.
 type testPeer struct {
 type testPeer struct {
+	peer *peer
+
 	net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
 	net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
 	app *p2p.MsgPipeRW    // Application layer reader/writer to simulate the local side
 	app *p2p.MsgPipeRW    // Application layer reader/writer to simulate the local side
-	*peer
 }
 }
 
 
 // newTestPeer creates a new peer registered at the given protocol manager.
 // newTestPeer creates a new peer registered at the given protocol manager.
-func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, shake bool, testCost uint64) (*testPeer, <-chan error) {
+func newTestPeer(t *testing.T, name string, version int, handler *serverHandler, shake bool, testCost uint64) (*testPeer, <-chan error) {
 	// Create a message pipe to communicate through
 	// Create a message pipe to communicate through
 	app, net := p2p.MsgPipe()
 	app, net := p2p.MsgPipe()
 
 
 	// Generate a random id and create the peer
 	// Generate a random id and create the peer
 	var id enode.ID
 	var id enode.ID
 	rand.Read(id[:])
 	rand.Read(id[:])
-
-	peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
+	peer := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), net)
 
 
 	// Start the peer on a new thread
 	// Start the peer on a new thread
-	errc := make(chan error, 1)
+	errCh := make(chan error, 1)
 	go func() {
 	go func() {
 		select {
 		select {
-		case pm.newPeerCh <- peer:
-			errc <- pm.handle(peer)
-		case <-pm.quitSync:
-			errc <- p2p.DiscQuitting
+		case <-handler.closeCh:
+			errCh <- p2p.DiscQuitting
+		case errCh <- handler.handle(peer):
 		}
 		}
 	}()
 	}()
 	tp := &testPeer{
 	tp := &testPeer{
@@ -294,17 +323,27 @@ func newTestPeer(t *testing.T, name string, version int, pm *ProtocolManager, sh
 	}
 	}
 	// Execute any implicitly requested handshakes and return
 	// Execute any implicitly requested handshakes and return
 	if shake {
 	if shake {
+		// Customize the cost table if required.
+		if testCost != 0 {
+			handler.server.costTracker.testCostList = testCostList(testCost)
+		}
 		var (
 		var (
-			genesis = pm.blockchain.Genesis()
-			head    = pm.blockchain.CurrentHeader()
-			td      = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64())
+			genesis = handler.blockchain.Genesis()
+			head    = handler.blockchain.CurrentHeader()
+			td      = handler.blockchain.GetTd(head.Hash(), head.Number.Uint64())
 		)
 		)
-		tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash(), testCost)
+		tp.handshake(t, td, head.Hash(), head.Number.Uint64(), genesis.Hash(), testCostList(testCost))
 	}
 	}
-	return tp, errc
+	return tp, errCh
+}
+
+// close terminates the local side of the peer, notifying the remote protocol
+// manager of termination.
+func (p *testPeer) close() {
+	p.app.Close()
 }
 }
 
 
-func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer, <-chan error, *peer, <-chan error) {
+func newTestPeerPair(name string, version int, server *serverHandler, client *clientHandler) (*testPeer, <-chan error, *testPeer, <-chan error) {
 	// Create a message pipe to communicate through
 	// Create a message pipe to communicate through
 	app, net := p2p.MsgPipe()
 	app, net := p2p.MsgPipe()
 
 
@@ -312,36 +351,34 @@ func newTestPeerPair(name string, version int, pm, pm2 *ProtocolManager) (*peer,
 	var id enode.ID
 	var id enode.ID
 	rand.Read(id[:])
 	rand.Read(id[:])
 
 
-	peer := pm.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), net)
-	peer2 := pm2.newPeer(version, NetworkId, p2p.NewPeer(id, name, nil), app)
+	peer1 := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), net)
+	peer2 := newPeer(version, NetworkId, false, p2p.NewPeer(id, name, nil), app)
 
 
 	// Start the peer on a new thread
 	// Start the peer on a new thread
-	errc := make(chan error, 1)
+	errc1 := make(chan error, 1)
 	errc2 := make(chan error, 1)
 	errc2 := make(chan error, 1)
 	go func() {
 	go func() {
 		select {
 		select {
-		case pm.newPeerCh <- peer:
-			errc <- pm.handle(peer)
-		case <-pm.quitSync:
-			errc <- p2p.DiscQuitting
+		case <-server.closeCh:
+			errc1 <- p2p.DiscQuitting
+		case errc1 <- server.handle(peer1):
 		}
 		}
 	}()
 	}()
 	go func() {
 	go func() {
 		select {
 		select {
-		case pm2.newPeerCh <- peer2:
-			errc2 <- pm2.handle(peer2)
-		case <-pm2.quitSync:
-			errc2 <- p2p.DiscQuitting
+		case <-client.closeCh:
+			errc1 <- p2p.DiscQuitting
+		case errc1 <- client.handle(peer2):
 		}
 		}
 	}()
 	}()
-	return peer, errc, peer2, errc2
+	return &testPeer{peer: peer1, net: net, app: app}, errc1, &testPeer{peer: peer2, net: app, app: net}, errc2
 }
 }
 
 
 // handshake simulates a trivial handshake that expects the same state from the
 // handshake simulates a trivial handshake that expects the same state from the
 // remote side as we are simulating locally.
 // remote side as we are simulating locally.
-func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, testCost uint64) {
+func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, costList RequestCostList) {
 	var expList keyValueList
 	var expList keyValueList
-	expList = expList.add("protocolVersion", uint64(p.version))
+	expList = expList.add("protocolVersion", uint64(p.peer.version))
 	expList = expList.add("networkId", uint64(NetworkId))
 	expList = expList.add("networkId", uint64(NetworkId))
 	expList = expList.add("headTd", td)
 	expList = expList.add("headTd", td)
 	expList = expList.add("headHash", head)
 	expList = expList.add("headHash", head)
@@ -356,7 +393,7 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
 	expList = expList.add("txRelay", nil)
 	expList = expList.add("txRelay", nil)
 	expList = expList.add("flowControl/BL", testBufLimit)
 	expList = expList.add("flowControl/BL", testBufLimit)
 	expList = expList.add("flowControl/MRR", testBufRecharge)
 	expList = expList.add("flowControl/MRR", testBufRecharge)
-	expList = expList.add("flowControl/MRC", testCostList(testCost))
+	expList = expList.add("flowControl/MRC", costList)
 
 
 	if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil {
 	if err := p2p.ExpectMsg(p.app, StatusMsg, expList); err != nil {
 		t.Fatalf("status recv: %v", err)
 		t.Fatalf("status recv: %v", err)
@@ -364,113 +401,119 @@ func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, headNu
 	if err := p2p.Send(p.app, StatusMsg, sendList); err != nil {
 	if err := p2p.Send(p.app, StatusMsg, sendList); err != nil {
 		t.Fatalf("status send: %v", err)
 		t.Fatalf("status send: %v", err)
 	}
 	}
-
-	p.fcParams = flowcontrol.ServerParams{
+	p.peer.fcParams = flowcontrol.ServerParams{
 		BufLimit:    testBufLimit,
 		BufLimit:    testBufLimit,
 		MinRecharge: testBufRecharge,
 		MinRecharge: testBufRecharge,
 	}
 	}
 }
 }
 
 
-// close terminates the local side of the peer, notifying the remote protocol
-// manager of termination.
-func (p *testPeer) close() {
-	p.app.Close()
-}
+type indexerCallback func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer)
 
 
-// TestEntity represents a network entity for testing with necessary auxiliary fields.
-type TestEntity struct {
+// testClient represents a client for testing with necessary auxiliary fields.
+type testClient struct {
+	clock   mclock.Clock
 	db      ethdb.Database
 	db      ethdb.Database
-	rPeer   *peer
-	tPeer   *testPeer
-	peers   *peerSet
-	pm      *ProtocolManager
+	peer    *testPeer
+	handler *clientHandler
+
+	chtIndexer       *core.ChainIndexer
+	bloomIndexer     *core.ChainIndexer
+	bloomTrieIndexer *core.ChainIndexer
+}
+
+// testServer represents a server for testing with necessary auxiliary fields.
+type testServer struct {
+	clock   mclock.Clock
 	backend *backends.SimulatedBackend
 	backend *backends.SimulatedBackend
+	db      ethdb.Database
+	peer    *testPeer
+	handler *serverHandler
 
 
-	// Indexers
 	chtIndexer       *core.ChainIndexer
 	chtIndexer       *core.ChainIndexer
 	bloomIndexer     *core.ChainIndexer
 	bloomIndexer     *core.ChainIndexer
 	bloomTrieIndexer *core.ChainIndexer
 	bloomTrieIndexer *core.ChainIndexer
 }
 }
 
 
-// newServerEnv creates a server testing environment with a connected test peer for testing purpose.
-func newServerEnv(t *testing.T, blocks int, protocol int, waitIndexers func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer)) (*TestEntity, func()) {
+func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, simClock bool, newPeer bool, testCost uint64) (*testServer, func()) {
 	db := rawdb.NewMemoryDatabase()
 	db := rawdb.NewMemoryDatabase()
 	indexers := testIndexers(db, nil, light.TestServerIndexerConfig)
 	indexers := testIndexers(db, nil, light.TestServerIndexerConfig)
 
 
-	pm, b := newTestProtocolManagerMust(t, false, blocks, nil, indexers, nil, db, nil, 0)
-	peer, _ := newTestPeer(t, "peer", protocol, pm, true, 0)
+	var clock mclock.Clock = &mclock.System{}
+	if simClock {
+		clock = &mclock.Simulated{}
+	}
+	handler, b := newTestServerHandler(blocks, indexers, db, newPeerSet(), clock)
+
+	var peer *testPeer
+	if newPeer {
+		peer, _ = newTestPeer(t, "peer", protocol, handler, true, testCost)
+	}
 
 
 	cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2]
 	cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2]
-	cIndexer.Start(pm.blockchain.(*core.BlockChain))
-	bIndexer.Start(pm.blockchain.(*core.BlockChain))
+	cIndexer.Start(handler.blockchain)
+	bIndexer.Start(handler.blockchain)
 
 
 	// Wait until indexers generate enough index data.
 	// Wait until indexers generate enough index data.
-	if waitIndexers != nil {
-		waitIndexers(cIndexer, bIndexer, btIndexer)
+	if callback != nil {
+		callback(cIndexer, bIndexer, btIndexer)
 	}
 	}
-
-	return &TestEntity{
-			db:               db,
-			tPeer:            peer,
-			pm:               pm,
-			backend:          b,
-			chtIndexer:       cIndexer,
-			bloomIndexer:     bIndexer,
-			bloomTrieIndexer: btIndexer,
-		}, func() {
+	server := &testServer{
+		clock:            clock,
+		backend:          b,
+		db:               db,
+		peer:             peer,
+		handler:          handler,
+		chtIndexer:       cIndexer,
+		bloomIndexer:     bIndexer,
+		bloomTrieIndexer: btIndexer,
+	}
+	teardown := func() {
+		if newPeer {
 			peer.close()
 			peer.close()
-			// Note bloom trie indexer will be closed by it parent recursively.
-			cIndexer.Close()
-			bIndexer.Close()
 			b.Close()
 			b.Close()
 		}
 		}
+		cIndexer.Close()
+		bIndexer.Close()
+	}
+	return server, teardown
 }
 }
 
 
-// newClientServerEnv creates a client/server arch environment with a connected les server and light client pair
-// for testing purpose.
-func newClientServerEnv(t *testing.T, blocks int, protocol int, waitIndexers func(*core.ChainIndexer, *core.ChainIndexer, *core.ChainIndexer), newPeer bool) (*TestEntity, *TestEntity, func()) {
-	db, ldb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase()
-	peers, lPeers := newPeerSet(), newPeerSet()
+func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, ulcServers []string, ulcFraction int, simClock bool, connect bool) (*testServer, *testClient, func()) {
+	sdb, cdb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase()
+	speers, cPeers := newPeerSet(), newPeerSet()
 
 
-	dist := newRequestDistributor(lPeers, make(chan struct{}), &mclock.System{})
-	rm := newRetrieveManager(lPeers, dist, nil)
-	odr := NewLesOdr(ldb, light.TestClientIndexerConfig, rm)
+	var clock mclock.Clock = &mclock.System{}
+	if simClock {
+		clock = &mclock.Simulated{}
+	}
+	dist := newRequestDistributor(cPeers, clock)
+	rm := newRetrieveManager(cPeers, dist, nil)
+	odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm)
 
 
-	indexers := testIndexers(db, nil, light.TestServerIndexerConfig)
-	lIndexers := testIndexers(ldb, odr, light.TestClientIndexerConfig)
+	sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig)
+	cIndexers := testIndexers(cdb, odr, light.TestClientIndexerConfig)
 
 
-	cIndexer, bIndexer, btIndexer := indexers[0], indexers[1], indexers[2]
-	lcIndexer, lbIndexer, lbtIndexer := lIndexers[0], lIndexers[1], lIndexers[2]
+	scIndexer, sbIndexer, sbtIndexer := sindexers[0], sindexers[1], sindexers[2]
+	ccIndexer, cbIndexer, cbtIndexer := cIndexers[0], cIndexers[1], cIndexers[2]
+	odr.SetIndexers(ccIndexer, cbIndexer, cbtIndexer)
 
 
-	odr.SetIndexers(lcIndexer, lbtIndexer, lbIndexer)
+	server, b := newTestServerHandler(blocks, sindexers, sdb, speers, clock)
+	client := newTestClientHandler(b, odr, cIndexers, cdb, cPeers, ulcServers, ulcFraction)
 
 
-	pm, b := newTestProtocolManagerMust(t, false, blocks, nil, indexers, peers, db, nil, 0)
-	lpm, lb := newTestProtocolManagerMust(t, true, 0, odr, lIndexers, lPeers, ldb, nil, 0)
+	scIndexer.Start(server.blockchain)
+	sbIndexer.Start(server.blockchain)
+	ccIndexer.Start(client.backend.blockchain)
+	cbIndexer.Start(client.backend.blockchain)
 
 
-	startIndexers := func(clientMode bool, pm *ProtocolManager) {
-		if clientMode {
-			lcIndexer.Start(pm.blockchain.(*light.LightChain))
-			lbIndexer.Start(pm.blockchain.(*light.LightChain))
-		} else {
-			cIndexer.Start(pm.blockchain.(*core.BlockChain))
-			bIndexer.Start(pm.blockchain.(*core.BlockChain))
-		}
+	if callback != nil {
+		callback(scIndexer, sbIndexer, sbtIndexer)
 	}
 	}
-
-	startIndexers(false, pm)
-	startIndexers(true, lpm)
-
-	// Execute wait until function if it is specified.
-	if waitIndexers != nil {
-		waitIndexers(cIndexer, bIndexer, btIndexer)
-	}
-
 	var (
 	var (
-		peer, lPeer *peer
-		err1, err2  <-chan error
+		speer, cpeer *testPeer
+		err1, err2   <-chan error
 	)
 	)
-	if newPeer {
-		peer, err1, lPeer, err2 = newTestPeerPair("peer", protocol, pm, lpm)
+	if connect {
+		cpeer, err1, speer, err2 = newTestPeerPair("peer", protocol, server, client)
 		select {
 		select {
 		case <-time.After(time.Millisecond * 100):
 		case <-time.After(time.Millisecond * 100):
 		case err := <-err1:
 		case err := <-err1:
@@ -479,32 +522,35 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, waitIndexers fun
 			t.Fatalf("peer 2 handshake error: %v", err)
 			t.Fatalf("peer 2 handshake error: %v", err)
 		}
 		}
 	}
 	}
-
-	return &TestEntity{
-			db:               db,
-			pm:               pm,
-			rPeer:            peer,
-			peers:            peers,
-			backend:          b,
-			chtIndexer:       cIndexer,
-			bloomIndexer:     bIndexer,
-			bloomTrieIndexer: btIndexer,
-		}, &TestEntity{
-			db:               ldb,
-			pm:               lpm,
-			rPeer:            lPeer,
-			peers:            lPeers,
-			backend:          lb,
-			chtIndexer:       lcIndexer,
-			bloomIndexer:     lbIndexer,
-			bloomTrieIndexer: lbtIndexer,
-		}, func() {
-			// Note bloom trie indexers will be closed by their parents recursively.
-			cIndexer.Close()
-			bIndexer.Close()
-			lcIndexer.Close()
-			lbIndexer.Close()
-			b.Close()
-			lb.Close()
+	s := &testServer{
+		clock:            clock,
+		backend:          b,
+		db:               sdb,
+		peer:             cpeer,
+		handler:          server,
+		chtIndexer:       scIndexer,
+		bloomIndexer:     sbIndexer,
+		bloomTrieIndexer: sbtIndexer,
+	}
+	c := &testClient{
+		clock:            clock,
+		db:               cdb,
+		peer:             speer,
+		handler:          client,
+		chtIndexer:       ccIndexer,
+		bloomIndexer:     cbIndexer,
+		bloomTrieIndexer: cbtIndexer,
+	}
+	teardown := func() {
+		if connect {
+			speer.close()
+			cpeer.close()
 		}
 		}
+		ccIndexer.Close()
+		cbIndexer.Close()
+		scIndexer.Close()
+		sbIndexer.Close()
+		b.Close()
+	}
+	return s, c, teardown
 }
 }

+ 75 - 151
les/ulc_test.go

@@ -17,151 +17,100 @@
 package les
 package les
 
 
 import (
 import (
-	"crypto/ecdsa"
+	"crypto/rand"
 	"fmt"
 	"fmt"
-	"math/big"
 	"net"
 	"net"
-	"reflect"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
-	"github.com/ethereum/go-ethereum/common/mclock"
-	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/crypto"
-	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 )
 )
 
 
-func TestULCSyncWithOnePeer(t *testing.T) {
-	f := newFullPeerPair(t, 1, 4)
-	l := newLightPeer(t, []string{f.Node.String()}, 100)
-
-	if reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) {
-		t.Fatal("blocks are equal")
-	}
-	_, _, err := connectPeers(f, l, 2)
-	if err != nil {
-		t.Fatal(err)
-	}
-	l.PM.fetcher.lock.Lock()
-	l.PM.fetcher.nextRequest()
-	l.PM.fetcher.lock.Unlock()
-
-	if !reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) {
-		t.Fatal("sync doesn't work")
-	}
-}
-
-func TestULCReceiveAnnounce(t *testing.T) {
-	f := newFullPeerPair(t, 1, 4)
-	l := newLightPeer(t, []string{f.Node.String()}, 100)
-	fPeer, lPeer, err := connectPeers(f, l, 2)
-	if err != nil {
-		t.Fatal(err)
-	}
-	l.PM.synchronise(fPeer)
-
-	//check that the sync is finished correctly
-	if !reflect.DeepEqual(f.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) {
-		t.Fatal("sync doesn't work")
-	}
-	l.PM.peers.lock.Lock()
-	if len(l.PM.peers.peers) == 0 {
-		t.Fatal("peer list should not be empty")
-	}
-	l.PM.peers.lock.Unlock()
-
-	time.Sleep(time.Second)
-	//send a signed announce message(payload doesn't matter)
-	td := f.PM.blockchain.GetTd(l.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Number.Uint64())
-	announce := announceData{
-		Number: l.PM.blockchain.CurrentHeader().Number.Uint64() + 1,
-		Td:     td.Add(td, big.NewInt(1)),
-	}
-	announce.sign(f.Key)
-	lPeer.SendAnnounce(announce)
-}
-
-func TestULCShouldNotSyncWithTwoPeersOneHaveEmptyChain(t *testing.T) {
-	f1 := newFullPeerPair(t, 1, 4)
-	f2 := newFullPeerPair(t, 2, 0)
-	l := newLightPeer(t, []string{f1.Node.String(), f2.Node.String()}, 100)
-	_, _, err := connectPeers(f1, l, 2)
-	if err != nil {
-		t.Fatal(err)
-	}
-	_, _, err = connectPeers(f2, l, 2)
-	if err != nil {
-		t.Fatal(err)
-	}
-	l.PM.fetcher.lock.Lock()
-	l.PM.fetcher.nextRequest()
-	l.PM.fetcher.lock.Unlock()
-
-	if reflect.DeepEqual(f2.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) {
-		t.Fatal("Incorrect hash: second peer has empty chain")
-	}
-}
-
-func TestULCShouldNotSyncWithThreePeersOneHaveEmptyChain(t *testing.T) {
-	f1 := newFullPeerPair(t, 1, 3)
-	f2 := newFullPeerPair(t, 2, 4)
-	f3 := newFullPeerPair(t, 3, 0)
+func TestULCAnnounceThresholdLes2(t *testing.T) { testULCAnnounceThreshold(t, 2) }
+func TestULCAnnounceThresholdLes3(t *testing.T) { testULCAnnounceThreshold(t, 3) }
+
+func testULCAnnounceThreshold(t *testing.T, protocol int) {
+	// todo figure out why it takes fetcher so longer to fetcher the announced header.
+	t.Skip("Sometimes it can failed")
+	var cases = []struct {
+		height    []int
+		threshold int
+		expect    uint64
+	}{
+		{[]int{1}, 100, 1},
+		{[]int{0, 0, 0}, 100, 0},
+		{[]int{1, 2, 3}, 30, 3},
+		{[]int{1, 2, 3}, 60, 2},
+		{[]int{3, 2, 1}, 67, 1},
+		{[]int{3, 2, 1}, 100, 1},
+	}
+	for _, testcase := range cases {
+		var (
+			servers   []*testServer
+			teardowns []func()
+			nodes     []*enode.Node
+			ids       []string
+		)
+		for i := 0; i < len(testcase.height); i++ {
+			s, n, teardown := newServerPeer(t, 0, protocol)
+
+			servers = append(servers, s)
+			nodes = append(nodes, n)
+			teardowns = append(teardowns, teardown)
+			ids = append(ids, n.String())
+		}
+		c, teardown := newLightPeer(t, protocol, ids, testcase.threshold)
 
 
-	l := newLightPeer(t, []string{f1.Node.String(), f2.Node.String(), f3.Node.String()}, 60)
-	_, _, err := connectPeers(f1, l, 2)
-	if err != nil {
-		t.Fatal(err)
-	}
-	_, _, err = connectPeers(f2, l, 2)
-	if err != nil {
-		t.Fatal(err)
-	}
-	_, _, err = connectPeers(f3, l, 2)
-	if err != nil {
-		t.Fatal(err)
-	}
-	l.PM.fetcher.lock.Lock()
-	l.PM.fetcher.nextRequest()
-	l.PM.fetcher.lock.Unlock()
+		// Connect all servers.
+		for i := 0; i < len(servers); i++ {
+			connect(servers[i].handler, nodes[i].ID(), c.handler, protocol)
+		}
+		for i := 0; i < len(servers); i++ {
+			for j := 0; j < testcase.height[i]; j++ {
+				servers[i].backend.Commit()
+			}
+		}
+		time.Sleep(1500 * time.Millisecond) // Ensure the fetcher has done its work.
+		head := c.handler.backend.blockchain.CurrentHeader().Number.Uint64()
+		if head != testcase.expect {
+			t.Fatalf("chain height mismatch, want %d, got %d", testcase.expect, head)
+		}
 
 
-	if !reflect.DeepEqual(f1.PM.blockchain.CurrentHeader().Hash(), l.PM.blockchain.CurrentHeader().Hash()) {
-		t.Fatal("Incorrect hash")
+		// Release all servers and client resources.
+		teardown()
+		for i := 0; i < len(teardowns); i++ {
+			teardowns[i]()
+		}
 	}
 	}
 }
 }
 
 
-type pairPeer struct {
-	Name string
-	Node *enode.Node
-	PM   *ProtocolManager
-	Key  *ecdsa.PrivateKey
-}
-
-func connectPeers(full, light pairPeer, version int) (*peer, *peer, error) {
+func connect(server *serverHandler, serverId enode.ID, client *clientHandler, protocol int) (*peer, *peer, error) {
 	// Create a message pipe to communicate through
 	// Create a message pipe to communicate through
 	app, net := p2p.MsgPipe()
 	app, net := p2p.MsgPipe()
 
 
-	peerLight := full.PM.newPeer(version, NetworkId, p2p.NewPeer(light.Node.ID(), light.Name, nil), net)
-	peerFull := light.PM.newPeer(version, NetworkId, p2p.NewPeer(full.Node.ID(), full.Name, nil), app)
+	var id enode.ID
+	rand.Read(id[:])
+
+	peer1 := newPeer(protocol, NetworkId, true, p2p.NewPeer(serverId, "", nil), net) // Mark server as trusted
+	peer2 := newPeer(protocol, NetworkId, false, p2p.NewPeer(id, "", nil), app)
 
 
 	// Start the peerLight on a new thread
 	// Start the peerLight on a new thread
 	errc1 := make(chan error, 1)
 	errc1 := make(chan error, 1)
 	errc2 := make(chan error, 1)
 	errc2 := make(chan error, 1)
 	go func() {
 	go func() {
 		select {
 		select {
-		case light.PM.newPeerCh <- peerFull:
-			errc1 <- light.PM.handle(peerFull)
-		case <-light.PM.quitSync:
+		case <-server.closeCh:
 			errc1 <- p2p.DiscQuitting
 			errc1 <- p2p.DiscQuitting
+		case errc1 <- server.handle(peer2):
 		}
 		}
 	}()
 	}()
 	go func() {
 	go func() {
 		select {
 		select {
-		case full.PM.newPeerCh <- peerLight:
-			errc2 <- full.PM.handle(peerLight)
-		case <-full.PM.quitSync:
-			errc2 <- p2p.DiscQuitting
+		case <-client.closeCh:
+			errc1 <- p2p.DiscQuitting
+		case errc1 <- client.handle(peer1):
 		}
 		}
 	}()
 	}()
 
 
@@ -172,48 +121,23 @@ func connectPeers(full, light pairPeer, version int) (*peer, *peer, error) {
 	case err := <-errc2:
 	case err := <-errc2:
 		return nil, nil, fmt.Errorf("peerFull handshake error: %v", err)
 		return nil, nil, fmt.Errorf("peerFull handshake error: %v", err)
 	}
 	}
-
-	return peerFull, peerLight, nil
+	return peer1, peer2, nil
 }
 }
 
 
-// newFullPeerPair creates node with full sync mode
-func newFullPeerPair(t *testing.T, index int, numberOfblocks int) pairPeer {
-	db := rawdb.NewMemoryDatabase()
-
-	pmFull, _ := newTestProtocolManagerMust(t, false, numberOfblocks, nil, nil, nil, db, nil, 0)
-
-	peerPairFull := pairPeer{
-		Name: "full node",
-		PM:   pmFull,
-	}
+// newServerPeer creates server peer.
+func newServerPeer(t *testing.T, blocks int, protocol int) (*testServer, *enode.Node, func()) {
+	s, teardown := newServerEnv(t, blocks, protocol, nil, false, false, 0)
 	key, err := crypto.GenerateKey()
 	key, err := crypto.GenerateKey()
 	if err != nil {
 	if err != nil {
 		t.Fatal("generate key err:", err)
 		t.Fatal("generate key err:", err)
 	}
 	}
-	peerPairFull.Key = key
-	peerPairFull.Node = enode.NewV4(&key.PublicKey, net.ParseIP("127.0.0.1"), 35000, 35000)
-	return peerPairFull
+	s.handler.server.privateKey = key
+	n := enode.NewV4(&key.PublicKey, net.ParseIP("127.0.0.1"), 35000, 35000)
+	return s, n, teardown
 }
 }
 
 
 // newLightPeer creates node with light sync mode
 // newLightPeer creates node with light sync mode
-func newLightPeer(t *testing.T, ulcServers []string, ulcFraction int) pairPeer {
-	peers := newPeerSet()
-	dist := newRequestDistributor(peers, make(chan struct{}), &mclock.System{})
-	rm := newRetrieveManager(peers, dist, nil)
-	ldb := rawdb.NewMemoryDatabase()
-
-	odr := NewLesOdr(ldb, light.DefaultClientIndexerConfig, rm)
-
-	pmLight, _ := newTestProtocolManagerMust(t, true, 0, odr, nil, peers, ldb, ulcServers, ulcFraction)
-	peerPairLight := pairPeer{
-		Name: "ulc node",
-		PM:   pmLight,
-	}
-	key, err := crypto.GenerateKey()
-	if err != nil {
-		t.Fatal("generate key err:", err)
-	}
-	peerPairLight.Key = key
-	peerPairLight.Node = enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
-	return peerPairLight
+func newLightPeer(t *testing.T, protocol int, ulcServers []string, ulcFraction int) (*testClient, func()) {
+	_, c, teardown := newClientServerEnv(t, 0, protocol, nil, ulcServers, ulcFraction, false, false)
+	return c, teardown
 }
 }

+ 3 - 3
light/odr_util.go

@@ -60,7 +60,7 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ
 		}
 		}
 	}
 	}
 	if number >= chtCount*odr.IndexerConfig().ChtSize {
 	if number >= chtCount*odr.IndexerConfig().ChtSize {
-		return nil, ErrNoTrustedCht
+		return nil, errNoTrustedCht
 	}
 	}
 	r := &ChtRequest{ChtRoot: GetChtRoot(db, chtCount-1, sectionHead), ChtNum: chtCount - 1, BlockNum: number, Config: odr.IndexerConfig()}
 	r := &ChtRequest{ChtRoot: GetChtRoot(db, chtCount-1, sectionHead), ChtNum: chtCount - 1, BlockNum: number, Config: odr.IndexerConfig()}
 	if err := odr.Retrieve(ctx, r); err != nil {
 	if err := odr.Retrieve(ctx, r); err != nil {
@@ -124,7 +124,7 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint
 	// Retrieve the block header and body contents
 	// Retrieve the block header and body contents
 	header := rawdb.ReadHeader(odr.Database(), hash, number)
 	header := rawdb.ReadHeader(odr.Database(), hash, number)
 	if header == nil {
 	if header == nil {
-		return nil, ErrNoHeader
+		return nil, errNoHeader
 	}
 	}
 	body, err := GetBody(ctx, odr, hash, number)
 	body, err := GetBody(ctx, odr, hash, number)
 	if err != nil {
 	if err != nil {
@@ -241,7 +241,7 @@ func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxLi
 		} else {
 		} else {
 			// TODO(rjl493456442) Convert sectionIndex to BloomTrie relative index
 			// TODO(rjl493456442) Convert sectionIndex to BloomTrie relative index
 			if sectionIdx >= bloomTrieCount {
 			if sectionIdx >= bloomTrieCount {
-				return nil, ErrNoTrustedBloomTrie
+				return nil, errNoTrustedBloomTrie
 			}
 			}
 			reqList = append(reqList, sectionIdx)
 			reqList = append(reqList, sectionIdx)
 			reqIdx = append(reqIdx, i)
 			reqIdx = append(reqIdx, i)

+ 3 - 3
light/postprocess.go

@@ -98,9 +98,9 @@ var (
 )
 )
 
 
 var (
 var (
-	ErrNoTrustedCht       = errors.New("no trusted canonical hash trie")
-	ErrNoTrustedBloomTrie = errors.New("no trusted bloom trie")
-	ErrNoHeader           = errors.New("header not found")
+	errNoTrustedCht       = errors.New("no trusted canonical hash trie")
+	errNoTrustedBloomTrie = errors.New("no trusted bloom trie")
+	errNoHeader           = errors.New("header not found")
 	chtPrefix             = []byte("chtRootV2-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash
 	chtPrefix             = []byte("chtRootV2-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash
 	ChtTablePrefix        = "cht-"
 	ChtTablePrefix        = "cht-"
 )
 )