Browse Source

les: check required message types in cost table (#19454)

Felföldi Zsolt 6 years ago
parent
commit
85b6823d16
2 changed files with 20 additions and 5 deletions
  1. 18 3
      les/peer.go
  2. 2 2
      les/peer_test.go

+ 18 - 3
les/peer.go

@@ -232,7 +232,11 @@ func (p *peer) GetRequestCost(msgcode uint64, amount int) uint64 {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
-	cost := p.fcCosts[msgcode].baseCost + p.fcCosts[msgcode].reqCost*uint64(amount)
+	costs := p.fcCosts[msgcode]
+	if costs == nil {
+		return 0
+	}
+	cost := costs.baseCost + costs.reqCost*uint64(amount)
 	if cost > p.fcParams.BufLimit {
 		cost = p.fcParams.BufLimit
 	}
@@ -243,8 +247,12 @@ func (p *peer) GetTxRelayCost(amount, size int) uint64 {
 	p.lock.RLock()
 	defer p.lock.RUnlock()
 
-	cost := p.fcCosts[SendTxV2Msg].baseCost + p.fcCosts[SendTxV2Msg].reqCost*uint64(amount)
-	sizeCost := p.fcCosts[SendTxV2Msg].baseCost + p.fcCosts[SendTxV2Msg].reqCost*uint64(size)/txSizeCostLimit
+	costs := p.fcCosts[SendTxV2Msg]
+	if costs == nil {
+		return 0
+	}
+	cost := costs.baseCost + costs.reqCost*uint64(amount)
+	sizeCost := costs.baseCost + costs.reqCost*uint64(size)/txSizeCostLimit
 	if sizeCost > cost {
 		cost = sizeCost
 	}
@@ -564,6 +572,13 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
 		p.fcParams = params
 		p.fcServer = flowcontrol.NewServerNode(params, &mclock.System{})
 		p.fcCosts = MRC.decode()
+		if !p.isOnlyAnnounce {
+			for msgCode := range reqAvgTimeCost {
+				if p.fcCosts[msgCode] == nil {
+					return errResp(ErrUselessPeer, "peer does not support message %d", msgCode)
+				}
+			}
+		}
 	}
 	p.headInfo = &announceData{Td: rTd, Hash: rHash, Number: rNum}
 	return nil

+ 2 - 2
les/peer_test.go

@@ -54,7 +54,7 @@ func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testi
 				l = l.add("txRelay", nil)
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
-				l = l.add("flowControl/MRC", RequestCostList{})
+				l = l.add("flowControl/MRC", testCostList())
 
 				return l
 			},
@@ -99,7 +99,7 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
 				l = l.add("txRelay", nil)
 				l = l.add("flowControl/BL", uint64(0))
 				l = l.add("flowControl/MRR", uint64(0))
-				l = l.add("flowControl/MRC", RequestCostList{})
+				l = l.add("flowControl/MRC", testCostList())
 
 				return l
 			},