瀏覽代碼

les: clean up server handler (#22357)

gary rong 4 年之前
父節點
當前提交
3ecfdccd9a
共有 2 個文件被更改,包括 114 次插入87 次删除
  1. 106 79
      les/server_handler.go
  2. 8 8
      les/server_requests.go

+ 106 - 79
les/server_handler.go

@@ -204,6 +204,90 @@ func (h *serverHandler) handle(p *clientPeer) error {
 	}
 }
 
+// beforeHandle will do a series of prechecks before handling message.
+func (h *serverHandler) beforeHandle(p *clientPeer, reqID, responseCount uint64, msg p2p.Msg, reqCnt uint64, maxCount uint64) (*servingTask, uint64) {
+	// Ensure that the request sent by client peer is valid
+	inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
+	if reqCnt == 0 || reqCnt > maxCount {
+		p.fcClient.OneTimeCost(inSizeCost)
+		return nil, 0
+	}
+	// Ensure that the client peer complies with the flow control
+	// rules agreed by both sides.
+	if p.isFrozen() {
+		p.fcClient.OneTimeCost(inSizeCost)
+		return nil, 0
+	}
+	maxCost := p.fcCosts.getMaxCost(msg.Code, reqCnt)
+	accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
+	if !accepted {
+		p.freeze()
+		p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
+		p.fcClient.OneTimeCost(inSizeCost)
+		return nil, 0
+	}
+	// Create a multi-stage task, estimate the time it takes for the task to
+	// execute, and cache it in the request service queue.
+	factor := h.server.costTracker.globalFactor()
+	if factor < 0.001 {
+		factor = 1
+		p.Log().Error("Invalid global cost factor", "factor", factor)
+	}
+	maxTime := uint64(float64(maxCost) / factor)
+	task := h.server.servingQueue.newTask(p, maxTime, priority)
+	if !task.start() {
+		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
+		return nil, 0
+	}
+	return task, maxCost
+}
+
+// Afterhandle will perform a series of operations after message handling,
+// such as updating flow control data, sending reply, etc.
+func (h *serverHandler) afterHandle(p *clientPeer, reqID, responseCount uint64, msg p2p.Msg, maxCost uint64, reqCnt uint64, task *servingTask, reply *reply) {
+	if reply != nil {
+		task.done()
+	}
+	p.responseLock.Lock()
+	defer p.responseLock.Unlock()
+
+	// Short circuit if the client is already frozen.
+	if p.isFrozen() {
+		realCost := h.server.costTracker.realCost(task.servingTime, msg.Size, 0)
+		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
+		return
+	}
+	// Positive correction buffer value with real cost.
+	var replySize uint32
+	if reply != nil {
+		replySize = reply.size()
+	}
+	var realCost uint64
+	if h.server.costTracker.testing {
+		realCost = maxCost // Assign a fake cost for testing purpose
+	} else {
+		realCost = h.server.costTracker.realCost(task.servingTime, msg.Size, replySize)
+		if realCost > maxCost {
+			realCost = maxCost
+		}
+	}
+	bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
+	if reply != nil {
+		// Feed cost tracker request serving statistic.
+		h.server.costTracker.updateStats(msg.Code, reqCnt, task.servingTime, realCost)
+		// Reduce priority "balance" for the specific peer.
+		p.balance.RequestServed(realCost)
+		p.queueSend(func() {
+			if err := reply.send(bv); err != nil {
+				select {
+				case p.errCh <- err:
+				default:
+				}
+			}
+		})
+	}
+}
+
 // handleMsg is invoked whenever an inbound message is received from a remote
 // peer. The remote connection is torn down upon returning any error.
 func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
@@ -221,9 +305,8 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
 	}
 	defer msg.Discard()
 
-	p.responseCount++
-	responseCount := p.responseCount
-
+	// Lookup the request handler table, ensure it's supported
+	// message type by the protocol.
 	req, ok := Les3[msg.Code]
 	if !ok {
 		p.Log().Trace("Received invalid message", "code", msg.Code)
@@ -232,98 +315,42 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error {
 	}
 	p.Log().Trace("Received " + req.Name)
 
+	// Decode the p2p message, resolve the concrete handler for it.
 	serve, reqID, reqCnt, err := req.Handle(msg)
 	if err != nil {
 		clientErrorMeter.Mark(1)
 		return errResp(ErrDecode, "%v: %v", msg, err)
 	}
-
 	if metrics.EnabledExpensive {
 		req.InPacketsMeter.Mark(1)
 		req.InTrafficMeter.Mark(int64(msg.Size))
 	}
+	p.responseCount++
+	responseCount := p.responseCount
 
-	// Short circuit if the peer is already frozen or the request is invalid.
-	inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0)
-	if p.isFrozen() || reqCnt == 0 || reqCnt > req.MaxCount {
-		p.fcClient.OneTimeCost(inSizeCost)
-		return nil
-	}
-	// Prepaid max cost units before request been serving.
-	maxCost := p.fcCosts.getMaxCost(msg.Code, reqCnt)
-	accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost)
-	if !accepted {
-		p.freeze()
-		p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge)))
-		p.fcClient.OneTimeCost(inSizeCost)
+	// First check this client message complies all rules before
+	// handling it and return a processor if all checks are passed.
+	task, maxCost := h.beforeHandle(p, reqID, responseCount, msg, reqCnt, req.MaxCount)
+	if task == nil {
 		return nil
 	}
-	// Create a multi-stage task, estimate the time it takes for the task to
-	// execute, and cache it in the request service queue.
-	factor := h.server.costTracker.globalFactor()
-	if factor < 0.001 {
-		factor = 1
-		p.Log().Error("Invalid global cost factor", "factor", factor)
-	}
-	maxTime := uint64(float64(maxCost) / factor)
-	task := h.server.servingQueue.newTask(p, maxTime, priority)
-	if task.start() {
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			reply := serve(h, p, task.waitOrStop)
-			if reply != nil {
-				task.done()
-			}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
 
-			p.responseLock.Lock()
-			defer p.responseLock.Unlock()
+		reply := serve(h, p, task.waitOrStop)
+		h.afterHandle(p, reqID, responseCount, msg, maxCost, reqCnt, task, reply)
 
-			// Short circuit if the client is already frozen.
-			if p.isFrozen() {
-				realCost := h.server.costTracker.realCost(task.servingTime, msg.Size, 0)
-				p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
-				return
-			}
-			// Positive correction buffer value with real cost.
-			var replySize uint32
+		if metrics.EnabledExpensive {
+			size := uint32(0)
 			if reply != nil {
-				replySize = reply.size()
-			}
-			var realCost uint64
-			if h.server.costTracker.testing {
-				realCost = maxCost // Assign a fake cost for testing purpose
-			} else {
-				realCost = h.server.costTracker.realCost(task.servingTime, msg.Size, replySize)
-				if realCost > maxCost {
-					realCost = maxCost
-				}
+				size = reply.size()
 			}
-			bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost)
-			if reply != nil {
-				// Feed cost tracker request serving statistic.
-				h.server.costTracker.updateStats(msg.Code, reqCnt, task.servingTime, realCost)
-				// Reduce priority "balance" for the specific peer.
-				p.balance.RequestServed(realCost)
-				p.queueSend(func() {
-					if err := reply.send(bv); err != nil {
-						select {
-						case p.errCh <- err:
-						default:
-						}
-					}
-				})
-				if metrics.EnabledExpensive {
-					req.OutPacketsMeter.Mark(1)
-					req.OutTrafficMeter.Mark(int64(replySize))
-					req.ServingTimeMeter.Update(time.Duration(task.servingTime))
-				}
-			}
-		}()
-	} else {
-		p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost)
-	}
-
+			req.OutPacketsMeter.Mark(1)
+			req.OutTrafficMeter.Mark(int64(size))
+			req.ServingTimeMeter.Update(time.Duration(task.servingTime))
+		}
+	}()
 	// If the client has made too much invalid request(e.g. request a non-existent data),
 	// reject them to prevent SPAM attack.
 	if p.getInvalid() > maxRequestErrors {

+ 8 - 8
les/server_requests.go

@@ -65,7 +65,7 @@ type serveRequestFn func(backend serverBackend, peer *clientPeer, waitOrStop fun
 
 // Les3 contains the request types supported by les/2 and les/3
 var Les3 = map[uint64]RequestType{
-	GetBlockHeadersMsg: RequestType{
+	GetBlockHeadersMsg: {
 		Name:             "block header request",
 		MaxCount:         MaxHeaderFetch,
 		InPacketsMeter:   miscInHeaderPacketsMeter,
@@ -75,7 +75,7 @@ var Les3 = map[uint64]RequestType{
 		ServingTimeMeter: miscServingTimeHeaderTimer,
 		Handle:           handleGetBlockHeaders,
 	},
-	GetBlockBodiesMsg: RequestType{
+	GetBlockBodiesMsg: {
 		Name:             "block bodies request",
 		MaxCount:         MaxBodyFetch,
 		InPacketsMeter:   miscInBodyPacketsMeter,
@@ -85,7 +85,7 @@ var Les3 = map[uint64]RequestType{
 		ServingTimeMeter: miscServingTimeBodyTimer,
 		Handle:           handleGetBlockBodies,
 	},
-	GetCodeMsg: RequestType{
+	GetCodeMsg: {
 		Name:             "code request",
 		MaxCount:         MaxCodeFetch,
 		InPacketsMeter:   miscInCodePacketsMeter,
@@ -95,7 +95,7 @@ var Les3 = map[uint64]RequestType{
 		ServingTimeMeter: miscServingTimeCodeTimer,
 		Handle:           handleGetCode,
 	},
-	GetReceiptsMsg: RequestType{
+	GetReceiptsMsg: {
 		Name:             "receipts request",
 		MaxCount:         MaxReceiptFetch,
 		InPacketsMeter:   miscInReceiptPacketsMeter,
@@ -105,7 +105,7 @@ var Les3 = map[uint64]RequestType{
 		ServingTimeMeter: miscServingTimeReceiptTimer,
 		Handle:           handleGetReceipts,
 	},
-	GetProofsV2Msg: RequestType{
+	GetProofsV2Msg: {
 		Name:             "les/2 proofs request",
 		MaxCount:         MaxProofsFetch,
 		InPacketsMeter:   miscInTrieProofPacketsMeter,
@@ -115,7 +115,7 @@ var Les3 = map[uint64]RequestType{
 		ServingTimeMeter: miscServingTimeTrieProofTimer,
 		Handle:           handleGetProofs,
 	},
-	GetHelperTrieProofsMsg: RequestType{
+	GetHelperTrieProofsMsg: {
 		Name:             "helper trie proof request",
 		MaxCount:         MaxHelperTrieProofsFetch,
 		InPacketsMeter:   miscInHelperTriePacketsMeter,
@@ -125,7 +125,7 @@ var Les3 = map[uint64]RequestType{
 		ServingTimeMeter: miscServingTimeHelperTrieTimer,
 		Handle:           handleGetHelperTrieProofs,
 	},
-	SendTxV2Msg: RequestType{
+	SendTxV2Msg: {
 		Name:             "new transactions",
 		MaxCount:         MaxTxSend,
 		InPacketsMeter:   miscInTxsPacketsMeter,
@@ -135,7 +135,7 @@ var Les3 = map[uint64]RequestType{
 		ServingTimeMeter: miscServingTimeTxTimer,
 		Handle:           handleSendTx,
 	},
-	GetTxStatusMsg: RequestType{
+	GetTxStatusMsg: {
 		Name:             "transaction status query request",
 		MaxCount:         MaxTxStatus,
 		InPacketsMeter:   miscInTxStatusPacketsMeter,