Pārlūkot izejas kodu

les/utils: UDP rate limiter (#21930)

* les/utils: Limiter

* les/utils: dropped prior weight vs variable cost logic, using fixed weights

* les/utils: always create node selector in addressGroup

* les/utils: renamed request weight to request cost

* les/utils: simplified and improved the DoS penalty mechanism

* les/utils: minor fixes

* les/utils: made selection weight calculation nicer

* les/utils: fixed linter warning

* les/utils: more precise and reliable probabilistic test

* les/utils: fixed linter warning
Felföldi Zsolt 4 gadi atpakaļ
vecāks
revīzija
7a800f98f6
3 mainītis faili ar 625 papildinājumiem un 14 dzēšanām
  1. 405 0
      les/utils/limiter.go
  2. 206 0
      les/utils/limiter_test.go
  3. 14 14
      les/utils/weighted_select.go

+ 405 - 0
les/utils/limiter.go

@@ -0,0 +1,405 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package utils
+
+import (
+	"sort"
+	"sync"
+
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+const maxSelectionWeight = 1000000000 // maximum selection weight of each individual node/address group
+
+// Limiter protects a network request serving mechanism from denial-of-service attacks.
+// It limits the total amount of resources used for serving requests while ensuring that
+// the most valuable connections always have a reasonable chance of being served.
+type Limiter struct {
+	lock sync.Mutex
+	cond *sync.Cond
+	quit bool
+
+	nodes                          map[enode.ID]*nodeQueue
+	addresses                      map[string]*addressGroup
+	addressSelect, valueSelect     *WeightedRandomSelect
+	maxValue                       float64
+	maxCost, sumCost, sumCostLimit uint
+	selectAddressNext              bool
+}
+
+// nodeQueue represents queued requests coming from a single node ID
+type nodeQueue struct {
+	queue                   []request // always nil if penaltyCost != 0
+	id                      enode.ID
+	address                 string
+	value                   float64
+	flatWeight, valueWeight uint64 // current selection weights in the address/value selectors
+	sumCost                 uint   // summed cost of requests queued by the node
+	penaltyCost             uint   // cumulative cost of dropped requests since last processed request
+	groupIndex              int
+}
+
+// addressGroup is a group of node IDs that have sent their last requests from the same
+// network address
+type addressGroup struct {
+	nodes                      []*nodeQueue
+	nodeSelect                 *WeightedRandomSelect
+	sumFlatWeight, groupWeight uint64
+}
+
+// request represents an incoming request scheduled for processing
+type request struct {
+	process chan chan struct{}
+	cost    uint
+}
+
+// flatWeight distributes weights equally between each active network address
+func flatWeight(item interface{}) uint64 { return item.(*nodeQueue).flatWeight }
+
+// add adds the node queue to the address group. It is the caller's responsibility to
+// add the address group to the address map and the address selector if it wasn't
+// there before.
+func (ag *addressGroup) add(nq *nodeQueue) {
+	if nq.groupIndex != -1 {
+		panic("added node queue is already in an address group")
+	}
+	l := len(ag.nodes)
+	nq.groupIndex = l
+	ag.nodes = append(ag.nodes, nq)
+	ag.sumFlatWeight += nq.flatWeight
+	ag.groupWeight = ag.sumFlatWeight / uint64(l+1)
+	ag.nodeSelect.Update(ag.nodes[l])
+}
+
+// update updates the selection weight of the node queue inside the address group.
+// It is the caller's responsibility to update the group's selection weight in the
+// address selector.
+func (ag *addressGroup) update(nq *nodeQueue, weight uint64) {
+	if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
+		panic("updated node queue is not in this address group")
+	}
+	ag.sumFlatWeight += weight - nq.flatWeight
+	nq.flatWeight = weight
+	ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes))
+	ag.nodeSelect.Update(nq)
+}
+
+// remove removes the node queue from the address group. It is the caller's responsibility
+// to remove the address group from the address map if it is empty.
+func (ag *addressGroup) remove(nq *nodeQueue) {
+	if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
+		panic("removed node queue is not in this address group")
+	}
+
+	l := len(ag.nodes) - 1
+	if nq.groupIndex != l {
+		ag.nodes[nq.groupIndex] = ag.nodes[l]
+		ag.nodes[nq.groupIndex].groupIndex = nq.groupIndex
+	}
+	nq.groupIndex = -1
+	ag.nodes = ag.nodes[:l]
+	ag.sumFlatWeight -= nq.flatWeight
+	if l >= 1 {
+		ag.groupWeight = ag.sumFlatWeight / uint64(l)
+	} else {
+		ag.groupWeight = 0
+	}
+	ag.nodeSelect.Remove(nq)
+}
+
+// choose selects one of the node queues belonging to the address group
+func (ag *addressGroup) choose() *nodeQueue {
+	return ag.nodeSelect.Choose().(*nodeQueue)
+}
+
+// NewLimiter creates a new Limiter
+func NewLimiter(sumCostLimit uint) *Limiter {
+	l := &Limiter{
+		addressSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }),
+		valueSelect:   NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }),
+		nodes:         make(map[enode.ID]*nodeQueue),
+		addresses:     make(map[string]*addressGroup),
+		sumCostLimit:  sumCostLimit,
+	}
+	l.cond = sync.NewCond(&l.lock)
+	go l.processLoop()
+	return l
+}
+
+// selectionWeights calculates the selection weights of a node for both the address and
+// the value selector. The selection weight depends on the next request cost or the
+// summed cost of recently dropped requests.
+func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) {
+	if value > l.maxValue {
+		l.maxValue = value
+	}
+	if value > 0 {
+		// normalize value to <= 1
+		value /= l.maxValue
+	}
+	if reqCost > l.maxCost {
+		l.maxCost = reqCost
+	}
+	relCost := float64(reqCost) / float64(l.maxCost)
+	var f float64
+	if relCost <= 0.001 {
+		f = 1
+	} else {
+		f = 0.001 / relCost
+	}
+	f *= maxSelectionWeight
+	flatWeight, valueWeight = uint64(f), uint64(f*value)
+	if flatWeight == 0 {
+		flatWeight = 1
+	}
+	return
+}
+
+// Add adds a new request to the node queue belonging to the given id. Value belongs
+// to the requesting node. A higher value gives the request a higher chance of being
+// served quickly in case of heavy load or a DDoS attack. Cost is a rough estimate
+// of the serving cost of the request. A lower cost also gives the request a
+// better chance.
+func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	process := make(chan chan struct{}, 1)
+	if l.quit {
+		close(process)
+		return process
+	}
+	if reqCost == 0 {
+		reqCost = 1
+	}
+	if nq, ok := l.nodes[id]; ok {
+		if nq.queue != nil {
+			nq.queue = append(nq.queue, request{process, reqCost})
+			nq.sumCost += reqCost
+			nq.value = value
+			if address != nq.address {
+				// known id sending request from a new address, move to different address group
+				l.removeFromGroup(nq)
+				l.addToGroup(nq, address)
+			}
+		} else {
+			// already waiting on a penalty, just add to the penalty cost and drop the request
+			nq.penaltyCost += reqCost
+			l.update(nq)
+			close(process)
+			return process
+		}
+	} else {
+		nq := &nodeQueue{
+			queue:      []request{{process, reqCost}},
+			id:         id,
+			value:      value,
+			sumCost:    reqCost,
+			groupIndex: -1,
+		}
+		nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value)
+		if len(l.nodes) == 0 {
+			l.cond.Signal()
+		}
+		l.nodes[id] = nq
+		if nq.valueWeight != 0 {
+			l.valueSelect.Update(nq)
+		}
+		l.addToGroup(nq, address)
+	}
+	l.sumCost += reqCost
+	if l.sumCost > l.sumCostLimit {
+		l.dropRequests()
+	}
+	return process
+}
+
+// update updates the selection weights of the node queue
+func (l *Limiter) update(nq *nodeQueue) {
+	var cost uint
+	if nq.queue != nil {
+		cost = nq.queue[0].cost
+	} else {
+		cost = nq.penaltyCost
+	}
+	flatWeight, valueWeight := l.selectionWeights(cost, nq.value)
+	ag := l.addresses[nq.address]
+	ag.update(nq, flatWeight)
+	l.addressSelect.Update(ag)
+	nq.valueWeight = valueWeight
+	l.valueSelect.Update(nq)
+}
+
+// addToGroup adds the node queue to the given address group. The group is created if
+// it does not exist yet.
+func (l *Limiter) addToGroup(nq *nodeQueue, address string) {
+	nq.address = address
+	ag := l.addresses[address]
+	if ag == nil {
+		ag = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)}
+		l.addresses[address] = ag
+	}
+	ag.add(nq)
+	l.addressSelect.Update(ag)
+}
+
+// removeFromGroup removes the node queue from its address group
+func (l *Limiter) removeFromGroup(nq *nodeQueue) {
+	ag := l.addresses[nq.address]
+	ag.remove(nq)
+	if len(ag.nodes) == 0 {
+		delete(l.addresses, nq.address)
+	}
+	l.addressSelect.Update(ag)
+}
+
+// remove removes the node queue from its address group, the nodes map and the value
+// selector
+func (l *Limiter) remove(nq *nodeQueue) {
+	l.removeFromGroup(nq)
+	if nq.valueWeight != 0 {
+		l.valueSelect.Remove(nq)
+	}
+	delete(l.nodes, nq.id)
+}
+
+// choose selects the next node queue to process.
+func (l *Limiter) choose() *nodeQueue {
+	if l.valueSelect.IsEmpty() || l.selectAddressNext {
+		if ag, ok := l.addressSelect.Choose().(*addressGroup); ok {
+			l.selectAddressNext = false
+			return ag.choose()
+		}
+	}
+	nq, _ := l.valueSelect.Choose().(*nodeQueue)
+	l.selectAddressNext = true
+	return nq
+}
+
+// processLoop processes requests sequentially
+func (l *Limiter) processLoop() {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	for {
+		if l.quit {
+			for _, nq := range l.nodes {
+				for _, request := range nq.queue {
+					close(request.process)
+				}
+			}
+			return
+		}
+		nq := l.choose()
+		if nq == nil {
+			l.cond.Wait()
+			continue
+		}
+		if nq.queue != nil {
+			request := nq.queue[0]
+			nq.queue = nq.queue[1:]
+			nq.sumCost -= request.cost
+			l.sumCost -= request.cost
+			l.lock.Unlock()
+			ch := make(chan struct{})
+			request.process <- ch
+			<-ch
+			l.lock.Lock()
+			if len(nq.queue) > 0 {
+				l.update(nq)
+			} else {
+				l.remove(nq)
+			}
+		} else {
+			// penalized queue removed, next request will be added to a clean queue
+			l.remove(nq)
+		}
+	}
+}
+
+// Stop stops the processing loop. All queued and future requests are rejected.
+func (l *Limiter) Stop() {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	l.quit = true
+	l.cond.Signal()
+}
+
+type (
+	dropList     []dropListItem
+	dropListItem struct {
+		nq       *nodeQueue
+		priority float64
+	}
+)
+
+func (l dropList) Len() int {
+	return len(l)
+}
+
+func (l dropList) Less(i, j int) bool {
+	return l[i].priority < l[j].priority
+}
+
+func (l dropList) Swap(i, j int) {
+	l[i], l[j] = l[j], l[i]
+}
+
+// dropRequests selects the nodes with the highest queued request cost to selection
+// weight ratio and drops their queued request. The empty node queues stay in the
+// selectors with a low selection weight in order to penalize these nodes.
+func (l *Limiter) dropRequests() {
+	var (
+		sumValue float64
+		list     dropList
+	)
+	for _, nq := range l.nodes {
+		sumValue += nq.value
+	}
+	for _, nq := range l.nodes {
+		if nq.sumCost == 0 {
+			continue
+		}
+		w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes))
+		if sumValue > 0 {
+			w += nq.value / sumValue
+		}
+		list = append(list, dropListItem{
+			nq:       nq,
+			priority: w / float64(nq.sumCost),
+		})
+	}
+	sort.Sort(list)
+	for _, item := range list {
+		for _, request := range item.nq.queue {
+			close(request.process)
+		}
+		// make the queue penalized; no more requests are accepted until the node is
+		// selected based on the penalty cost which is the cumulative cost of all dropped
+		// requests. This ensures that sending excess requests is always penalized
+		// and incentivizes the sender to stop for a while if no replies are received.
+		item.nq.queue = nil
+		item.nq.penaltyCost = item.nq.sumCost
+		l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost
+		item.nq.sumCost = 0
+		l.update(item.nq)
+		if l.sumCost <= l.sumCostLimit/2 {
+			return
+		}
+	}
+}

+ 206 - 0
les/utils/limiter_test.go

@@ -0,0 +1,206 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package utils
+
+import (
+	"math/rand"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+const (
+	ltTolerance = 0.03
+	ltRounds    = 7
+)
+
+type (
+	ltNode struct {
+		addr, id         int
+		value, exp       float64
+		cost             uint
+		reqRate          float64
+		reqMax, runCount int
+		lastTotalCost    uint
+
+		served, dropped int
+	}
+
+	ltResult struct {
+		node *ltNode
+		ch   chan struct{}
+	}
+
+	limTest struct {
+		limiter            *Limiter
+		results            chan ltResult
+		runCount           int
+		expCost, totalCost uint
+	}
+)
+
+func (lt *limTest) request(n *ltNode) {
+	var (
+		address string
+		id      enode.ID
+	)
+	if n.addr >= 0 {
+		address = string([]byte{byte(n.addr)})
+	} else {
+		var b [32]byte
+		rand.Read(b[:])
+		address = string(b[:])
+	}
+	if n.id >= 0 {
+		id = enode.ID{byte(n.id)}
+	} else {
+		rand.Read(id[:])
+	}
+	lt.runCount++
+	n.runCount++
+	cch := lt.limiter.Add(id, address, n.value, n.cost)
+	go func() {
+		lt.results <- ltResult{n, <-cch}
+	}()
+}
+
+func (lt *limTest) moreRequests(n *ltNode) {
+	maxStart := int(float64(lt.totalCost-n.lastTotalCost) * n.reqRate)
+	if maxStart != 0 {
+		n.lastTotalCost = lt.totalCost
+	}
+	for n.reqMax > n.runCount && maxStart > 0 {
+		lt.request(n)
+		maxStart--
+	}
+}
+
+func (lt *limTest) process() {
+	res := <-lt.results
+	lt.runCount--
+	res.node.runCount--
+	if res.ch != nil {
+		res.node.served++
+		if res.node.exp != 0 {
+			lt.expCost += res.node.cost
+		}
+		lt.totalCost += res.node.cost
+		close(res.ch)
+	} else {
+		res.node.dropped++
+	}
+}
+
+func TestLimiter(t *testing.T) {
+	limTests := [][]*ltNode{
+		{ // one id from an individual address and two ids from a shared address
+			{addr: 0, id: 0, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.5},
+			{addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
+			{addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
+		},
+		{ // varying request costs
+			{addr: 0, id: 0, value: 0, cost: 10, reqRate: 0.2, reqMax: 1, exp: 0.5},
+			{addr: 1, id: 1, value: 0, cost: 3, reqRate: 0.5, reqMax: 1, exp: 0.25},
+			{addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
+		},
+		{ // different request rate
+			{addr: 0, id: 0, value: 0, cost: 1, reqRate: 2, reqMax: 2, exp: 0.5},
+			{addr: 1, id: 1, value: 0, cost: 1, reqRate: 10, reqMax: 10, exp: 0.25},
+			{addr: 1, id: 2, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25},
+		},
+		{ // adding value
+			{addr: 0, id: 0, value: 3, cost: 1, reqRate: 1, reqMax: 1, exp: (0.5 + 0.3) / 2},
+			{addr: 1, id: 1, value: 0, cost: 1, reqRate: 1, reqMax: 1, exp: 0.25 / 2},
+			{addr: 1, id: 2, value: 7, cost: 1, reqRate: 1, reqMax: 1, exp: (0.25 + 0.7) / 2},
+		},
+		{ // DoS attack from a single address with a single id
+			{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 3, id: 3, value: 0, cost: 1, reqRate: 10, reqMax: 1000000000, exp: 0},
+		},
+		{ // DoS attack from a single address with different ids
+			{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 3, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0},
+		},
+		{ // DDoS attack from different addresses with a single id
+			{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: -1, id: 3, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0},
+		},
+		{ // DDoS attack from different addresses with different ids
+			{addr: 0, id: 0, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 1, id: 1, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: 2, id: 2, value: 1, cost: 1, reqRate: 1, reqMax: 1, exp: 0.3333},
+			{addr: -1, id: -1, value: 0, cost: 1, reqRate: 1, reqMax: 1000000000, exp: 0},
+		},
+	}
+
+	lt := &limTest{
+		limiter: NewLimiter(100),
+		results: make(chan ltResult),
+	}
+	for _, test := range limTests {
+		lt.expCost, lt.totalCost = 0, 0
+		iterCount := 10000
+		for j := 0; j < ltRounds; j++ {
+			// try to reach expected target range in multiple rounds with increasing iteration counts
+			last := j == ltRounds-1
+			for _, n := range test {
+				lt.request(n)
+			}
+			for i := 0; i < iterCount; i++ {
+				lt.process()
+				for _, n := range test {
+					lt.moreRequests(n)
+				}
+			}
+			for lt.runCount > 0 {
+				lt.process()
+			}
+			if spamRatio := 1 - float64(lt.expCost)/float64(lt.totalCost); spamRatio > 0.5*(1+ltTolerance) {
+				t.Errorf("Spam ratio too high (%f)", spamRatio)
+			}
+			fail, success := false, true
+			for _, n := range test {
+				if n.exp != 0 {
+					if n.dropped > 0 {
+						t.Errorf("Dropped %d requests of non-spam node", n.dropped)
+						fail = true
+					}
+					r := float64(n.served) * float64(n.cost) / float64(lt.expCost)
+					if r < n.exp*(1-ltTolerance) || r > n.exp*(1+ltTolerance) {
+						if last {
+							// print error only if the target is still not reached in the last round
+							t.Errorf("Request ratio (%f) does not match expected value (%f)", r, n.exp)
+						}
+						success = false
+					}
+				}
+			}
+			if fail || success {
+				break
+			}
+			// neither failed nor succeeded; try more iterations to reach probability targets
+			iterCount *= 2
+		}
+	}
+	lt.limiter.Stop()
+}

+ 14 - 14
les/utils/weighted_select.go

@@ -52,17 +52,17 @@ func (w *WeightedRandomSelect) Remove(item WrsItem) {
 
 // IsEmpty returns true if the set is empty
 func (w *WeightedRandomSelect) IsEmpty() bool {
-	return w.root.sumWeight == 0
+	return w.root.sumCost == 0
 }
 
 // setWeight sets an item's weight to a specific value (removes it if zero)
 func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
-	if weight > math.MaxInt64-w.root.sumWeight {
-		// old weight is still included in sumWeight, remove and check again
+	if weight > math.MaxInt64-w.root.sumCost {
+		// old weight is still included in sumCost, remove and check again
 		w.setWeight(item, 0)
-		if weight > math.MaxInt64-w.root.sumWeight {
-			log.Error("WeightedRandomSelect overflow", "sumWeight", w.root.sumWeight, "new weight", weight)
-			weight = math.MaxInt64 - w.root.sumWeight
+		if weight > math.MaxInt64-w.root.sumCost {
+			log.Error("WeightedRandomSelect overflow", "sumCost", w.root.sumCost, "new weight", weight)
+			weight = math.MaxInt64 - w.root.sumCost
 		}
 	}
 	idx, ok := w.idx[item]
@@ -75,9 +75,9 @@ func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
 		if weight != 0 {
 			if w.root.itemCnt == w.root.maxItems {
 				// add a new level
-				newRoot := &wrsNode{sumWeight: w.root.sumWeight, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches}
+				newRoot := &wrsNode{sumCost: w.root.sumCost, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches}
 				newRoot.items[0] = w.root
-				newRoot.weights[0] = w.root.sumWeight
+				newRoot.weights[0] = w.root.sumCost
 				w.root = newRoot
 			}
 			w.idx[item] = w.root.insert(item, weight)
@@ -91,10 +91,10 @@ func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
 // updates its weight and selects another one
 func (w *WeightedRandomSelect) Choose() WrsItem {
 	for {
-		if w.root.sumWeight == 0 {
+		if w.root.sumCost == 0 {
 			return nil
 		}
-		val := uint64(rand.Int63n(int64(w.root.sumWeight)))
+		val := uint64(rand.Int63n(int64(w.root.sumCost)))
 		choice, lastWeight := w.root.choose(val)
 		weight := w.wfn(choice)
 		if weight != lastWeight {
@@ -112,7 +112,7 @@ const wrsBranches = 8 // max number of branches in the wrsNode tree
 type wrsNode struct {
 	items                    [wrsBranches]interface{}
 	weights                  [wrsBranches]uint64
-	sumWeight                uint64
+	sumCost                  uint64
 	level, itemCnt, maxItems int
 }
 
@@ -126,7 +126,7 @@ func (n *wrsNode) insert(item WrsItem, weight uint64) int {
 		}
 	}
 	n.itemCnt++
-	n.sumWeight += weight
+	n.sumCost += weight
 	n.weights[branch] += weight
 	if n.level == 0 {
 		n.items[branch] = item
@@ -150,7 +150,7 @@ func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
 		oldWeight := n.weights[idx]
 		n.weights[idx] = weight
 		diff := weight - oldWeight
-		n.sumWeight += diff
+		n.sumCost += diff
 		if weight == 0 {
 			n.items[idx] = nil
 			n.itemCnt--
@@ -161,7 +161,7 @@ func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
 	branch := idx / branchItems
 	diff := n.items[branch].(*wrsNode).setWeight(idx-branch*branchItems, weight)
 	n.weights[branch] += diff
-	n.sumWeight += diff
+	n.sumCost += diff
 	if weight == 0 {
 		n.itemCnt--
 	}