@@ -0,0 +1,433 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package utils
+
+import (
+	"sort"
+	"sync"
+
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+const maxSelectionWeight = 1000000000 // maximum selection weight of each individual node/address group
+
+// Limiter protects a network request serving mechanism from denial-of-service attacks.
+// It limits the total amount of resources used for serving requests while ensuring that
+// the most valuable connections always have a reasonable chance of being served.
+type Limiter struct {
+	lock sync.Mutex
+	cond *sync.Cond
+	quit bool
+
+	nodes                          map[enode.ID]*nodeQueue
+	addresses                      map[string]*addressGroup
+	addressSelect, valueSelect     *WeightedRandomSelect
+	maxValue                       float64
+	maxCost, sumCost, sumCostLimit uint
+	selectAddressNext              bool
+}
+
+// nodeQueue represents queued requests coming from a single node ID
+type nodeQueue struct {
+	queue                   []request // always nil if penaltyCost != 0
+	id                      enode.ID
+	address                 string
+	value                   float64
+	flatWeight, valueWeight uint64 // current selection weights in the address/value selectors
+	sumCost                 uint   // summed cost of requests queued by the node
+	penaltyCost             uint   // cumulative cost of dropped requests since last processed request
+	groupIndex              int
+}
+
+// addressGroup is a group of node IDs that have sent their last requests from the same
+// network address
+type addressGroup struct {
+	nodes                      []*nodeQueue
+	nodeSelect                 *WeightedRandomSelect
+	sumFlatWeight, groupWeight uint64
+}
+
+// request represents an incoming request scheduled for processing
+type request struct {
+	process chan chan struct{} // delivers a "start serving" channel when the request is selected; closed if the request is dropped
+	cost    uint               // estimated serving cost of the request
+}
+
+// flatWeight distributes weights equally between each active network address
+func flatWeight(item interface{}) uint64 { return item.(*nodeQueue).flatWeight }
+
+// add adds the node queue to the address group. It is the caller's responsibility to
+// add the address group to the address map and the address selector if it wasn't
+// there before.
+func (ag *addressGroup) add(nq *nodeQueue) {
+	if nq.groupIndex != -1 {
+		panic("added node queue is already in an address group")
+	}
+	l := len(ag.nodes)
+	nq.groupIndex = l
+	ag.nodes = append(ag.nodes, nq)
+	ag.sumFlatWeight += nq.flatWeight
+	ag.groupWeight = ag.sumFlatWeight / uint64(l+1)
+	ag.nodeSelect.Update(ag.nodes[l])
+}
+
+// update updates the selection weight of the node queue inside the address group.
+// It is the caller's responsibility to update the group's selection weight in the
+// address selector.
+func (ag *addressGroup) update(nq *nodeQueue, weight uint64) {
+	if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
+		panic("updated node queue is not in this address group")
+	}
+	ag.sumFlatWeight += weight - nq.flatWeight
+	nq.flatWeight = weight
+	ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes))
+	ag.nodeSelect.Update(nq)
+}
+
+// remove removes the node queue from the address group. It is the caller's responsibility
+// to remove the address group from the address map if it is empty.
+func (ag *addressGroup) remove(nq *nodeQueue) {
+	if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
+		panic("removed node queue is not in this address group")
+	}
+
+	l := len(ag.nodes) - 1
+	if nq.groupIndex != l {
+		ag.nodes[nq.groupIndex] = ag.nodes[l]
+		ag.nodes[nq.groupIndex].groupIndex = nq.groupIndex
+	}
+	nq.groupIndex = -1
+	ag.nodes = ag.nodes[:l]
+	ag.sumFlatWeight -= nq.flatWeight
+	if l >= 1 {
+		ag.groupWeight = ag.sumFlatWeight / uint64(l)
+	} else {
+		ag.groupWeight = 0
+	}
+	ag.nodeSelect.Remove(nq)
+}
+
+// choose selects one of the node queues belonging to the address group
+func (ag *addressGroup) choose() *nodeQueue {
+	return ag.nodeSelect.Choose().(*nodeQueue)
+}
+
+// NewLimiter creates a new Limiter and starts its processing loop. Queued requests
+// are dropped whenever their total queued cost exceeds sumCostLimit.
+func NewLimiter(sumCostLimit uint) *Limiter {
+	l := &Limiter{
+		addressSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }),
+		valueSelect:   NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }),
+		nodes:         make(map[enode.ID]*nodeQueue),
+		addresses:     make(map[string]*addressGroup),
+		sumCostLimit:  sumCostLimit,
+	}
+	l.cond = sync.NewCond(&l.lock)
+	go l.processLoop()
+	return l
+}
+
+// selectionWeights calculates the selection weights of a node for both the address and
+// the value selector. The selection weight depends on the next request cost or the
+// summed cost of recently dropped requests.
+func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) {
+	if value > l.maxValue {
+		l.maxValue = value
+	}
+	if value > 0 {
+		// normalize value to <= 1
+		value /= l.maxValue
+	}
+	if reqCost > l.maxCost {
+		l.maxCost = reqCost
+	}
+	relCost := float64(reqCost) / float64(l.maxCost)
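+	// the weight curve is flat at the maximum for cheap requests and inversely
+	// proportional to cost once the cost exceeds 0.1% of the highest cost seen
+	// so far, i.e. weight ~= maxSelectionWeight * min(1, 0.001/relCost)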
+	var f float64
+	if relCost <= 0.001 {
+		f = 1
+	} else {
+		f = 0.001 / relCost
+	}
+	f *= maxSelectionWeight
+	flatWeight, valueWeight = uint64(f), uint64(f*value)
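+	// ensure a non-zero flat weight so that every queued node stays selectable
+	// through the address selector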
+	if flatWeight == 0 {
+		flatWeight = 1
+	}
+	return
+}
+
+// Add adds a new request to the node queue belonging to the given id. Value is
+// the priority weight of the requesting node; a higher value gives the request a
+// better chance of being served quickly in case of heavy load or a DDoS attack.
+// Cost is a rough estimate of the serving cost of the request; a lower cost also
+// gives the request a better chance.
+func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	process := make(chan chan struct{}, 1)
+	if l.quit {
+		close(process)
+		return process
+	}
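+	// enforce a minimum cost of 1 so that requests cannot bypass cost accounting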
+	if reqCost == 0 {
+		reqCost = 1
+	}
+	if nq, ok := l.nodes[id]; ok {
+		if nq.queue != nil {
+			nq.queue = append(nq.queue, request{process, reqCost})
+			nq.sumCost += reqCost
+			nq.value = value
+			if address != nq.address {
+				// known id sending request from a new address, move to different address group
+				l.removeFromGroup(nq)
+				l.addToGroup(nq, address)
+			}
+		} else {
+			// already waiting on a penalty, just add to the penalty cost and drop the request
+			nq.penaltyCost += reqCost
+			l.update(nq)
+			close(process)
+			return process
+		}
+	} else {
+		nq := &nodeQueue{
+			queue:      []request{{process, reqCost}},
+			id:         id,
+			value:      value,
+			sumCost:    reqCost,
+			groupIndex: -1,
+		}
+		nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value)
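+		// wake up the processing loop if it has been idle with no queued nodes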
+		if len(l.nodes) == 0 {
+			l.cond.Signal()
+		}
+		l.nodes[id] = nq
+		if nq.valueWeight != 0 {
+			l.valueSelect.Update(nq)
+		}
+		l.addToGroup(nq, address)
+	}
+	l.sumCost += reqCost
+	if l.sumCost > l.sumCostLimit {
+		l.dropRequests()
+	}
+	return process
+}
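+
+// Example usage of Add from a request handler (illustrative sketch; the
+// identifiers limiter, serveRequest, id, addr, value and cost are hypothetical):
+//
+//	process := limiter.Add(id, addr, value, cost)
+//	if start := <-process; start != nil {
+//		serveRequest()
+//		close(start) // signal the limiter that serving has finished
+//	}
+//	// a nil start channel means the request was dropped or the limiter stopped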
+
+// update updates the selection weights of the node queue
+func (l *Limiter) update(nq *nodeQueue) {
+	var cost uint
+	if nq.queue != nil {
+		cost = nq.queue[0].cost
+	} else {
+		cost = nq.penaltyCost
+	}
+	flatWeight, valueWeight := l.selectionWeights(cost, nq.value)
+	ag := l.addresses[nq.address]
+	ag.update(nq, flatWeight)
+	l.addressSelect.Update(ag)
+	nq.valueWeight = valueWeight
+	l.valueSelect.Update(nq)
+}
+
+// addToGroup adds the node queue to the given address group. The group is created if
+// it does not exist yet.
+func (l *Limiter) addToGroup(nq *nodeQueue, address string) {
+	nq.address = address
+	ag := l.addresses[address]
+	if ag == nil {
+		ag = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)}
+		l.addresses[address] = ag
+	}
+	ag.add(nq)
+	l.addressSelect.Update(ag)
+}
+
+// removeFromGroup removes the node queue from its address group
+func (l *Limiter) removeFromGroup(nq *nodeQueue) {
+	ag := l.addresses[nq.address]
+	ag.remove(nq)
+	if len(ag.nodes) == 0 {
+		delete(l.addresses, nq.address)
+	}
+	l.addressSelect.Update(ag)
+}
+
+// remove removes the node queue from its address group, the nodes map and the value
+// selector
+func (l *Limiter) remove(nq *nodeQueue) {
+	l.removeFromGroup(nq)
+	if nq.valueWeight != 0 {
+		l.valueSelect.Remove(nq)
+	}
+	delete(l.nodes, nq.id)
+}
+
+// choose selects the next node queue to process.
+func (l *Limiter) choose() *nodeQueue {
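+	// alternate between the address-based and the value-based selector so that
+	// each of them serves roughly half of the requests; fall back to the other
+	// selector if the preferred one has nothing to offer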
+	if l.valueSelect.IsEmpty() || l.selectAddressNext {
+		if ag, ok := l.addressSelect.Choose().(*addressGroup); ok {
+			l.selectAddressNext = false
+			return ag.choose()
+		}
+	}
+	nq, _ := l.valueSelect.Choose().(*nodeQueue)
+	l.selectAddressNext = true
+	return nq
+}
+
+// processLoop processes requests sequentially
+func (l *Limiter) processLoop() {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	for {
+		if l.quit {
+			for _, nq := range l.nodes {
+				for _, request := range nq.queue {
+					close(request.process)
+				}
+			}
+			return
+		}
+		nq := l.choose()
+		if nq == nil {
+			l.cond.Wait()
+			continue
+		}
+		if nq.queue != nil {
+			request := nq.queue[0]
+			nq.queue = nq.queue[1:]
+			nq.sumCost -= request.cost
+			l.sumCost -= request.cost
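+			// release the lock while the request is being served so that new
+			// requests can be queued concurrently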
+			l.lock.Unlock()
+			ch := make(chan struct{})
+			request.process <- ch
+			<-ch
+			l.lock.Lock()
+			if len(nq.queue) > 0 {
+				l.update(nq)
+			} else {
+				l.remove(nq)
+			}
+		} else {
+			// penalized queue removed, next request will be added to a clean queue
+			l.remove(nq)
+		}
+	}
+}
+
+// Stop stops the processing loop. All queued and future requests are rejected.
+func (l *Limiter) Stop() {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	l.quit = true
+	l.cond.Signal()
+}
+
+type (
+	dropList     []dropListItem
+	dropListItem struct {
+		nq       *nodeQueue
+		priority float64
+	}
+)
+
+func (l dropList) Len() int {
+	return len(l)
+}
+
+func (l dropList) Less(i, j int) bool {
+	return l[i].priority < l[j].priority
+}
+
+func (l dropList) Swap(i, j int) {
+	l[i], l[j] = l[j], l[i]
+}
+
+// dropRequests selects the nodes with the highest queued request cost to selection
+// weight ratio and drops their queued requests. The emptied node queues stay in the
+// selectors with a low selection weight in order to penalize these nodes.
+func (l *Limiter) dropRequests() {
+	var (
+		sumValue float64
+		list     dropList
+	)
+	for _, nq := range l.nodes {
+		sumValue += nq.value
+	}
+	for _, nq := range l.nodes {
+		if nq.sumCost == 0 {
+			continue
+		}
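+		// the drop priority is the node's fair share of the serving capacity
+		// (an equal split between addresses and between nodes behind the same
+		// address, plus its share of the total request value) divided by the
+		// queued cost; the queues with the lowest priority are dropped first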
+		w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes))
+		if sumValue > 0 {
+			w += nq.value / sumValue
+		}
+		list = append(list, dropListItem{
+			nq:       nq,
+			priority: w / float64(nq.sumCost),
+		})
+	}
+	sort.Sort(list)
+	for _, item := range list {
+		for _, request := range item.nq.queue {
+			close(request.process)
+		}
+		// make the queue penalized; no more requests are accepted until the node is
+		// selected based on the penalty cost which is the cumulative cost of all dropped
+		// requests. This ensures that sending excess requests is always penalized
+		// and incentivizes the sender to stop for a while if no replies are received.
+		item.nq.queue = nil
+		item.nq.penaltyCost = item.nq.sumCost
+		l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost
+		item.nq.sumCost = 0
+		l.update(item.nq)
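+		// stop dropping once the total queued cost falls to half of the limit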
+		if l.sumCost <= l.sumCostLimit/2 {
+			return
+		}
+	}
+}