
les: wait for all task goroutines before dropping the peer (#20010)

* les: wait all task routines before drop the peer

* les: address comments

* les: fix issue
gary rong 6 years ago
parent
commit
68502595f6
7 changed files with 84 additions and 53 deletions
  1. les/benchmark.go (+ 2 - 1)
  2. les/clientpool.go (+ 38 - 30)
  3. les/clientpool_test.go (+ 5 - 5)
  4. les/peer.go (+ 5 - 5)
  5. les/server.go (+ 0 - 2)
  6. les/server_handler.go (+ 32 - 10)
  7. les/test_helper.go (+ 2 - 0)

+ 2 - 1
les/benchmark.go

@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"math/big"
 	"math/rand"
+	"sync"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -312,7 +313,7 @@ func (h *serverHandler) measure(setup *benchmarkSetup, count int) error {
 	}()
 	go func() {
 		for i := 0; i < count; i++ {
-			if err := h.handleMsg(serverPeer); err != nil {
+			if err := h.handleMsg(serverPeer, &sync.WaitGroup{}); err != nil {
 				errCh <- err
 				return
 			}
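
handleMsg now takes a *sync.WaitGroup so it can register any task goroutines it spawns; the benchmark has nothing to wait for, so it hands in a fresh, throwaway WaitGroup on every call. A minimal, self-contained sketch of that calling pattern, where serveOne is an invented stand-in for handleMsg (not the les API):

package main

import (
	"fmt"
	"sync"
)

// serveOne stands in for handleMsg: it may spawn background work and
// registers that work on the caller-supplied WaitGroup.
func serveOne(wg *sync.WaitGroup, i int) error {
	wg.Add(1)
	go func() {
		defer wg.Done()
		fmt.Println("background task", i)
	}()
	return nil
}

func main() {
	// A caller that never needs to wait for task completion, like the
	// benchmark loop above, can simply pass a throwaway WaitGroup per call.
	for i := 0; i < 3; i++ {
		_ = serveOne(&sync.WaitGroup{}, i)
	}
}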

+ 38 - 30
les/clientpool.go

@@ -181,52 +181,53 @@ func (f *clientPool) stop() {
 	f.lock.Unlock()
 }
 
-// registerPeer implements peerSetNotify
-func (f *clientPool) registerPeer(p *peer) {
-	c := f.connect(p, 0)
-	if c != nil {
-		p.balanceTracker = &c.balanceTracker
-	}
-}
-
 // connect should be called after a successful handshake. If the connection was
 // rejected, there is no need to call disconnect.
-func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo {
+func (f *clientPool) connect(peer clientPeer, capacity uint64) bool {
 	f.lock.Lock()
 	defer f.lock.Unlock()
 
+	// Short circuit if clientPool is already closed.
 	if f.closed {
-		return nil
+		return false
 	}
-	address := peer.freeClientId()
-	id := peer.ID()
-	idStr := peerIdToString(id)
+	// Dedup connected peers.
+	id, freeID := peer.ID(), peer.freeClientId()
 	if _, ok := f.connectedMap[id]; ok {
 		clientRejectedMeter.Mark(1)
-		log.Debug("Client already connected", "address", address, "id", idStr)
-		return nil
+		log.Debug("Client already connected", "address", freeID, "id", peerIdToString(id))
+		return false
 	}
+	// Create a clientInfo but do not add it yet
 	now := f.clock.Now()
-	// create a clientInfo but do not add it yet
-	e := &clientInfo{pool: f, peer: peer, address: address, queueIndex: -1, id: id}
 	posBalance := f.getPosBalance(id).value
-	e.priority = posBalance != 0
+	e := &clientInfo{pool: f, peer: peer, address: freeID, queueIndex: -1, id: id, priority: posBalance != 0}
+
 	var negBalance uint64
-	nb := f.negBalanceMap[address]
+	nb := f.negBalanceMap[freeID]
 	if nb != nil {
 		negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now)) / fixedPointMultiplier))
 	}
+	// If the client is a free client, assign it the low free capacity;
+	// otherwise assign the given value (priority client).
 	if !e.priority {
 		capacity = f.freeClientCap
 	}
-	// check whether it fits into connectedQueue
+	// Ensure the capacity is never lower than the free capacity.
 	if capacity < f.freeClientCap {
 		capacity = f.freeClientCap
 	}
 	e.capacity = capacity
+
 	e.balanceTracker.init(f.clock, capacity)
 	e.balanceTracker.setBalance(posBalance, negBalance)
 	f.setClientPriceFactors(e)
+
+	// If the number of clients already connected in the clientpool exceeds its
+	// capacity, evict some clients with the lowest priority.
+	//
+	// If the priority of the newly added client is lower than the priority of
+	// all connected clients, the client is rejected.
 	newCapacity := f.connectedCapacity + capacity
 	newCount := f.connectedQueue.Size() + 1
 	if newCapacity > f.capacityLimit || newCount > f.countLimit {
@@ -248,8 +249,8 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo {
 				f.connectedQueue.Push(c)
 			}
 			clientRejectedMeter.Mark(1)
-			log.Debug("Client rejected", "address", address, "id", idStr)
-			return nil
+			log.Debug("Client rejected", "address", freeID, "id", peerIdToString(id))
+			return false
 		}
 		// accept new client, drop old ones
 		for _, c := range kickList {
@@ -258,7 +259,7 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo {
 	}
 	// client accepted, finish setting it up
 	if nb != nil {
-		delete(f.negBalanceMap, address)
+		delete(f.negBalanceMap, freeID)
 		f.negBalanceQueue.Remove(nb.queueIndex)
 	}
 	if e.priority {
@@ -272,13 +273,8 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo {
 		e.peer.updateCapacity(e.capacity)
 	}
 	clientConnectedMeter.Mark(1)
-	log.Debug("Client accepted", "address", address)
-	return e
-}
-
-// unregisterPeer implements peerSetNotify
-func (f *clientPool) unregisterPeer(p *peer) {
-	f.disconnect(p)
+	log.Debug("Client accepted", "address", freeID)
+	return true
 }
 
 // disconnect should be called when a connection is terminated. If the disconnection
@@ -378,6 +374,18 @@ func (f *clientPool) setLimits(count int, totalCap uint64) {
 	})
 }
 
+// requestCost feeds the cost of a served request into the given peer's balance tracker.
+func (f *clientPool) requestCost(p *peer, cost uint64) {
+	f.lock.Lock()
+	defer f.lock.Unlock()
+
+	info, exist := f.connectedMap[p.ID()]
+	if !exist || f.closed {
+		return
+	}
+	info.balanceTracker.requestCost(cost)
+}
+
 // logOffset calculates the time-dependent offset for the logarithmic
 // representation of negative balance
 func (f *clientPool) logOffset(now mclock.AbsTime) int64 {
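
Taken together, the clientpool.go changes turn the pool into the single owner of per-client balance state: connect now reports acceptance with a bool instead of returning the internal clientInfo, disconnect releases the slot, and the new requestCost charges a served request to the client looked up by peer ID. The toy sketch below illustrates only that caller-facing contract; toyPool and its fields are invented for the illustration and deliberately ignore priorities, balances and eviction:

package main

import (
	"fmt"
	"sync"
)

// toyPool mimics the caller-facing shape of clientPool: accept or reject a
// connection, release it later, and account request costs by peer ID.
type toyPool struct {
	mu        sync.Mutex
	limit     int
	connected map[string]uint64 // peer id -> accumulated request cost
}

func (p *toyPool) connect(id string) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	if _, ok := p.connected[id]; ok || len(p.connected) >= p.limit {
		return false // duplicate peer or pool full
	}
	p.connected[id] = 0
	return true
}

func (p *toyPool) disconnect(id string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.connected, id)
}

func (p *toyPool) requestCost(id string, cost uint64) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if _, ok := p.connected[id]; ok {
		p.connected[id] += cost
	}
}

func main() {
	pool := &toyPool{limit: 1, connected: make(map[string]uint64)}
	fmt.Println(pool.connect("a")) // true: accepted
	fmt.Println(pool.connect("b")) // false: pool is full
	pool.requestCost("a", 100)
	pool.disconnect("a")
}

Returning a bool keeps the balanceTracker entirely inside the pool, which is what allows the peer.balanceTracker field to be deleted in les/peer.go below.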

+ 5 - 5
les/clientpool_test.go

@@ -83,14 +83,14 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
 
 	// pool should accept new peers up to its connected limit
 	for i := 0; i < connLimit; i++ {
-		if pool.connect(poolTestPeer(i), 0) != nil {
+		if pool.connect(poolTestPeer(i), 0) {
 			connected[i] = true
 		} else {
 			t.Fatalf("Test peer #%d rejected", i)
 		}
 	}
 	// since all accepted peers are new and should not be kicked out, the next one should be rejected
-	if pool.connect(poolTestPeer(connLimit), 0) != nil {
+	if pool.connect(poolTestPeer(connLimit), 0) {
 		connected[connLimit] = true
 		t.Fatalf("Peer accepted over connected limit")
 	}
@@ -116,7 +116,7 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
 				connTicks[i] += tickCounter
 			}
 		} else {
-			if pool.connect(poolTestPeer(i), 0) != nil {
+			if pool.connect(poolTestPeer(i), 0) {
 				connected[i] = true
 				connTicks[i] -= tickCounter
 			}
@@ -159,7 +159,7 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
 	}
 
 	// a previously unknown peer should be accepted now
-	if pool.connect(poolTestPeer(54321), 0) == nil {
+	if !pool.connect(poolTestPeer(54321), 0) {
 		t.Fatalf("Previously unknown peer rejected")
 	}
 
@@ -173,7 +173,7 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
 		pool.connect(poolTestPeer(i), 0)
 	}
 	// expect pool to remember known nodes and kick out one of them to accept a new one
-	if pool.connect(poolTestPeer(54322), 0) == nil {
+	if !pool.connect(poolTestPeer(54322), 0) {
 		t.Errorf("Previously unknown peer rejected after restarting pool")
 	}
 	pool.stop()

+ 5 - 5
les/peer.go

@@ -94,6 +94,7 @@ type peer struct {
 	sendQueue *execQueue
 
 	errCh chan error
+
 	// responseLock ensures that responses are queued in the same order as
 	// RequestProcessed is called
 	responseLock  sync.Mutex
@@ -107,11 +108,10 @@ type peer struct {
 	updateTime     mclock.AbsTime
 	frozen         uint32 // 1 if client is in frozen state
 
-	fcClient       *flowcontrol.ClientNode // nil if the peer is server only
-	fcServer       *flowcontrol.ServerNode // nil if the peer is client only
-	fcParams       flowcontrol.ServerParams
-	fcCosts        requestCostTable
-	balanceTracker *balanceTracker // set by clientPool.connect, used and removed by serverHandler.
+	fcClient *flowcontrol.ClientNode // nil if the peer is server only
+	fcServer *flowcontrol.ServerNode // nil if the peer is client only
+	fcParams flowcontrol.ServerParams
+	fcCosts  requestCostTable
 
 	trusted                 bool
 	onlyAnnounce            bool

+ 0 - 2
les/server.go

@@ -112,9 +112,7 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
 		maxCapacity = totalRecharge
 	}
 	srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2)
-
 	srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) })
-	srv.peers.notify(srv.clientPool)
 
 	checkpoint := srv.latestLocalCheckpoint()
 	if !checkpoint.Empty() {
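
newClientPool is now given a kick-out callback that unregisters a dropped client from the peer set, and the peers.notify registration goes away because connect/disconnect are called directly from the handler. The callback runs in its own goroutine, presumably so the eviction path never blocks on (or re-enters) whatever locks are held while the pool decides to drop a client. A small self-contained sketch of that wiring, with invented pool/evict/onKick names:

package main

import (
	"fmt"
	"sync"
)

// pool stands in for clientPool: it holds a callback that is invoked whenever
// a connected client has to be kicked out to make room for another one.
type pool struct {
	mu     sync.Mutex
	onKick func(id string)
}

func (p *pool) evict(id string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	// Fire the callback asynchronously so it can safely call back into code
	// that might, in turn, need p.mu again.
	go p.onKick(id)
}

func main() {
	var done sync.WaitGroup
	done.Add(1)
	p := &pool{onKick: func(id string) {
		defer done.Done()
		fmt.Println("unregister peer", id) // stands in for peers.Unregister
	}}
	p.evict("peer-1")
	done.Wait()
}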

+ 32 - 10
les/server_handler.go

@@ -54,7 +54,10 @@ const (
 	MaxTxStatus              = 256 // Amount of transactions to queried per request
 )
 
-var errTooManyInvalidRequest = errors.New("too many invalid requests made")
+var (
+	errTooManyInvalidRequest = errors.New("too many invalid requests made")
+	errFullClientPool        = errors.New("client pool is full")
+)
 
 // serverHandler is responsible for serving light client and process
 // all incoming light requests.
@@ -124,23 +127,26 @@ func (h *serverHandler) handle(p *peer) error {
 	}
 	defer p.fcClient.Disconnect()
 
+	// Disconnect the inbound peer if it's rejected by clientPool
+	if !h.server.clientPool.connect(p, 0) {
+		p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool)
+		return errFullClientPool
+	}
 	// Register the peer locally
 	if err := h.server.peers.Register(p); err != nil {
+		h.server.clientPool.disconnect(p)
 		p.Log().Error("Light Ethereum peer registration failed", "err", err)
 		return err
 	}
 	clientConnectionGauge.Update(int64(h.server.peers.Len()))
 
-	// add dummy balance tracker for tests
-	if p.balanceTracker == nil {
-		p.balanceTracker = &balanceTracker{}
-		p.balanceTracker.init(&mclock.System{}, 1)
-	}
+	var wg sync.WaitGroup // Wait group used to track all in-flight task goroutines.
 
 	connectedAt := mclock.Now()
 	defer func() {
-		p.balanceTracker = nil
+		wg.Wait() // Ensure all background task goroutines have exited.
 		h.server.peers.Unregister(p.id)
+		h.server.clientPool.disconnect(p)
 		clientConnectionGauge.Update(int64(h.server.peers.Len()))
 		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
 	}()
@@ -153,7 +159,7 @@ func (h *serverHandler) handle(p *peer) error {
 			return err
 		default:
 		}
-		if err := h.handleMsg(p); err != nil {
+		if err := h.handleMsg(p, &wg); err != nil {
 			p.Log().Debug("Light Ethereum message handling failed", "err", err)
 			return err
 		}
@@ -162,7 +168,7 @@ func (h *serverHandler) handle(p *peer) error {
 
 // handleMsg is invoked whenever an inbound message is received from a remote
 // peer. The remote connection is torn down upon returning any error.
-func (h *serverHandler) handleMsg(p *peer) error {
+func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error {
 	// Read the next message from the remote peer, and ensure it's fully consumed
 	msg, err := p.rw.ReadMsg()
 	if err != nil {
@@ -243,7 +249,7 @@ func (h *serverHandler) handleMsg(p *peer) error {
 			// Feed cost tracker request serving statistic.
 			h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
 			// Reduce priority "balance" for the specific peer.
-			p.balanceTracker.requestCost(realCost)
+			h.server.clientPool.requestCost(p, realCost)
 		}
 		if reply != nil {
 			p.queueSend(func() {
@@ -273,7 +279,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		}
 		query := req.Query
 		if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				hashMode := query.Origin.Hash != (common.Hash{})
 				first := true
 				maxNonCanonical := uint64(100)
@@ -387,7 +395,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		)
 		reqCnt := len(req.Hashes)
 		if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				for i, hash := range req.Hashes {
 					if i != 0 && !task.waitOrStop() {
 						sendResponse(req.ReqID, 0, nil, task.servingTime)
@@ -433,7 +443,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		)
 		reqCnt := len(req.Reqs)
 		if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				for i, request := range req.Reqs {
 					if i != 0 && !task.waitOrStop() {
 						sendResponse(req.ReqID, 0, nil, task.servingTime)
@@ -502,7 +514,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		)
 		reqCnt := len(req.Hashes)
 		if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				for i, hash := range req.Hashes {
 					if i != 0 && !task.waitOrStop() {
 						sendResponse(req.ReqID, 0, nil, task.servingTime)
@@ -557,7 +571,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		)
 		reqCnt := len(req.Reqs)
 		if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				nodes := light.NewNodeSet()
 
 				for i, request := range req.Reqs {
@@ -658,7 +674,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		)
 		reqCnt := len(req.Reqs)
 		if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				var (
 					lastIdx  uint64
 					lastType uint
@@ -725,7 +743,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		}
 		reqCnt := len(req.Txs)
 		if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				stats := make([]light.TxStatus, len(req.Txs))
 				for i, tx := range req.Txs {
 					if i != 0 && !task.waitOrStop() {
@@ -771,7 +791,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
 		}
 		reqCnt := len(req.Hashes)
 		if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
+			wg.Add(1)
 			go func() {
+				defer wg.Done()
 				stats := make([]light.TxStatus, len(req.Hashes))
 				for i, hash := range req.Hashes {
 					if i != 0 && !task.waitOrStop() {
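
The pattern repeated across all of these hunks is the core of the change: every request-serving goroutine is registered on the per-connection WaitGroup with wg.Add(1) and defer wg.Done(), and the handler's deferred cleanup calls wg.Wait() before the peer is unregistered and its client-pool slot released. A self-contained sketch of that lifecycle, using an invented handleConnection and plain int requests in place of handle/handleMsg and real LES messages:

package main

import (
	"fmt"
	"sync"
	"time"
)

// handleConnection serves requests for one peer. Each request is handled in
// its own goroutine, tracked by a per-connection WaitGroup; the deferred
// teardown waits for all of them before releasing per-peer state.
func handleConnection(requests <-chan int) {
	var wg sync.WaitGroup
	defer func() {
		wg.Wait() // every task goroutine has exited past this point
		fmt.Println("peer unregistered, pool slot released")
	}()
	for req := range requests {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			time.Sleep(10 * time.Millisecond) // pretend to serve the request
			fmt.Println("served request", id)
		}(req)
	}
}

func main() {
	reqs := make(chan int)
	go func() {
		for i := 0; i < 3; i++ {
			reqs <- i
		}
		close(reqs)
	}()
	handleConnection(reqs)
}

Without the Wait, a task goroutine could still be touching per-peer state while that state is being torn down, which is the situation the commit title describes.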

+ 2 - 0
les/test_helper.go

@@ -280,6 +280,8 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
 	}
 	server.costTracker, server.freeCapacity = newCostTracker(db, server.config)
 	server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism.
+	server.clientPool = newClientPool(db, 1, 10000, clock, nil)
+	server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool
 	server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true })
 	if server.oracle != nil {
 		server.oracle.start(simulation)