Просмотр исходного кода

p2p accounting (#17951)

* p2p/protocols: introduced protocol accounting

* p2p/protocols: added TestExchange simulation

* p2p/protocols: add accounting simulation

* p2p/protocols: remove unnecessary tests

* p2p/protocols: comments for accounting simulation

* p2p/protocols: addressed PR comments

* p2p/protocols: finalized accounting implementation

* p2p/protocols: removed unused code

* p2p/protocols: addressed @nonsense PR comments
holisticode 7 лет назад
Родитель
Сommit
8ed4739176

+ 172 - 0
p2p/protocols/accounting.go

@@ -0,0 +1,172 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package protocols
+
+import "github.com/ethereum/go-ethereum/metrics"
+
+//define some metrics
+var (
+	//NOTE: these metrics just define the interfaces and are currently *NOT persisted* over sessions
+	//All metrics are cumulative
+
+	//total amount of units credited
+	mBalanceCredit = metrics.NewRegisteredCounterForced("account.balance.credit", nil)
+	//total amount of units debited
+	mBalanceDebit = metrics.NewRegisteredCounterForced("account.balance.debit", nil)
+	//total amount of bytes credited
+	mBytesCredit = metrics.NewRegisteredCounterForced("account.bytes.credit", nil)
+	//total amount of bytes debited
+	mBytesDebit = metrics.NewRegisteredCounterForced("account.bytes.debit", nil)
+	//total amount of credited messages
+	mMsgCredit = metrics.NewRegisteredCounterForced("account.msg.credit", nil)
+	//total amount of debited messages
+	mMsgDebit = metrics.NewRegisteredCounterForced("account.msg.debit", nil)
+	//how many times local node had to drop remote peers
+	mPeerDrops = metrics.NewRegisteredCounterForced("account.peerdrops", nil)
+	//how many times local node overdrafted and dropped
+	mSelfDrops = metrics.NewRegisteredCounterForced("account.selfdrops", nil)
+)
+
+//Prices defines how prices are being passed on to the accounting instance
+type Prices interface {
+	//Return the Price for a message
+	Price(interface{}) *Price
+}
+
+type Payer bool
+
+const (
+	Sender   = Payer(true)
+	Receiver = Payer(false)
+)
+
+//Price represents the costs of a message
+type Price struct {
+	Value   uint64 //
+	PerByte bool   //True if the price is per byte or for unit
+	Payer   Payer
+}
+
+//For gives back the price for a message
+//A protocol provides the message price in absolute value
+//This method then returns the correct signed amount,
+//depending on who pays, which is identified by the `payer` argument:
+//`Send` will pass a `Sender` payer, `Receive` will pass the `Receiver` argument.
+//Thus: If Sending and sender pays, amount positive, otherwise negative
+//If Receiving, and receiver pays, amount positive, otherwise negative
+func (p *Price) For(payer Payer, size uint32) int64 {
+	price := p.Value
+	if p.PerByte {
+		price *= uint64(size)
+	}
+	if p.Payer == payer {
+		return 0 - int64(price)
+	}
+	return int64(price)
+}
+
+//Balance is the actual accounting instance
+//Balance defines the operations needed for accounting
+//Implementations internally maintain the balance for every peer
+type Balance interface {
+	//Adds amount to the local balance with remote node `peer`;
+	//positive amount = credit local node
+	//negative amount = debit local node
+	Add(amount int64, peer *Peer) error
+}
+
+//Accounting implements the Hook interface
+//It interfaces to the balances through the Balance interface,
+//while interfacing with protocols and its prices through the Prices interface
+type Accounting struct {
+	Balance //interface to accounting logic
+	Prices  //interface to prices logic
+}
+
+func NewAccounting(balance Balance, po Prices) *Accounting {
+	ah := &Accounting{
+		Prices:  po,
+		Balance: balance,
+	}
+	return ah
+}
+
+//Implement Hook.Send
+// Send takes a peer, a size and a msg and
+// - calculates the cost for the local node sending a msg of size to peer using the Prices interface
+// - credits/debits local node using balance interface
+func (ah *Accounting) Send(peer *Peer, size uint32, msg interface{}) error {
+	//get the price for a message (through the protocol spec)
+	price := ah.Price(msg)
+	//this message doesn't need accounting
+	if price == nil {
+		return nil
+	}
+	//evaluate the price for sending messages
+	costToLocalNode := price.For(Sender, size)
+	//do the accounting
+	err := ah.Add(costToLocalNode, peer)
+	//record metrics: just increase counters for user-facing metrics
+	ah.doMetrics(costToLocalNode, size, err)
+	return err
+}
+
+//Implement Hook.Receive
+// Receive takes a peer, a size and a msg and
+// - calculates the cost for the local node receiving a msg of size from peer using the Prices interface
+// - credits/debits local node using balance interface
+func (ah *Accounting) Receive(peer *Peer, size uint32, msg interface{}) error {
+	//get the price for a message (through the protocol spec)
+	price := ah.Price(msg)
+	//this message doesn't need accounting
+	if price == nil {
+		return nil
+	}
+	//evaluate the price for receiving messages
+	costToLocalNode := price.For(Receiver, size)
+	//do the accounting
+	err := ah.Add(costToLocalNode, peer)
+	//record metrics: just increase counters for user-facing metrics
+	ah.doMetrics(costToLocalNode, size, err)
+	return err
+}
+
+//record some metrics
+//this is not an error handling. `err` is returned by both `Send` and `Receive`
+//`err` will only be non-nil if a limit has been violated (overdraft), in which case the peer has been dropped.
+//if the limit has been violated and `err` is thus not nil:
+// * if the price is positive, local node has been credited; thus `err` implicitly signals the REMOTE has been dropped
+// * if the price is negative, local node has been debited, thus `err` implicitly signals LOCAL node "overdraft"
+func (ah *Accounting) doMetrics(price int64, size uint32, err error) {
+	if price > 0 {
+		mBalanceCredit.Inc(price)
+		mBytesCredit.Inc(int64(size))
+		mMsgCredit.Inc(1)
+		if err != nil {
+			//increase the number of times a remote node has been dropped due to "overdraft"
+			mPeerDrops.Inc(1)
+		}
+	} else {
+		mBalanceDebit.Inc(price)
+		mBytesDebit.Inc(int64(size))
+		mMsgDebit.Inc(1)
+		if err != nil {
+			//increase the number of times the local node has done an "overdraft" in respect to other nodes
+			mSelfDrops.Inc(1)
+		}
+	}
+}

+ 310 - 0
p2p/protocols/accounting_simulation_test.go

@@ -0,0 +1,310 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package protocols
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"math/rand"
+	"reflect"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/mattn/go-colorable"
+
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/rpc"
+
+	"github.com/ethereum/go-ethereum/node"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/simulations"
+	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
+)
+
+const (
+	content = "123456789"
+)
+
+var (
+	nodes    = flag.Int("nodes", 30, "number of nodes to create (default 30)")
+	msgs     = flag.Int("msgs", 100, "number of messages sent by node (default 100)")
+	loglevel = flag.Int("loglevel", 0, "verbosity of logs")
+	rawlog   = flag.Bool("rawlog", false, "remove terminal formatting from logs")
+)
+
+func init() {
+	flag.Parse()
+	log.PrintOrigins(true)
+	log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(!*rawlog))))
+}
+
+//TestAccountingSimulation runs a p2p/simulations simulation
+//It creates a *nodes number of nodes, connects each one with each other,
+//then sends out a random selection of messages up to *msgs amount of messages
+//from the test protocol spec.
+//The spec has some accounted messages defined through the Prices interface.
+//The test does accounting for all the message exchanged, and then checks
+//that every node has the same balance with a peer, but with opposite signs.
+//Balance(AwithB) = 0 - Balance(BwithA) or Abs|Balance(AwithB)| == Abs|Balance(BwithA)|
+func TestAccountingSimulation(t *testing.T) {
+	//setup the balances objects for every node
+	bal := newBalances(*nodes)
+	//define the node.Service for this test
+	services := adapters.Services{
+		"accounting": func(ctx *adapters.ServiceContext) (node.Service, error) {
+			return bal.newNode(), nil
+		},
+	}
+	//setup the simulation
+	adapter := adapters.NewSimAdapter(services)
+	net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{DefaultService: "accounting"})
+	defer net.Shutdown()
+
+	// we send msgs messages per node, wait for all messages to arrive
+	bal.wg.Add(*nodes * *msgs)
+	trigger := make(chan enode.ID)
+	go func() {
+		// wait for all of them to arrive
+		bal.wg.Wait()
+		// then trigger a check
+		// the selected node for the trigger is irrelevant,
+		// we just want to trigger the end of the simulation
+		trigger <- net.Nodes[0].ID()
+	}()
+
+	// create nodes and start them
+	for i := 0; i < *nodes; i++ {
+		conf := adapters.RandomNodeConfig()
+		bal.id2n[conf.ID] = i
+		if _, err := net.NewNodeWithConfig(conf); err != nil {
+			t.Fatal(err)
+		}
+		if err := net.Start(conf.ID); err != nil {
+			t.Fatal(err)
+		}
+	}
+	// fully connect nodes
+	for i, n := range net.Nodes {
+		for _, m := range net.Nodes[i+1:] {
+			if err := net.Connect(n.ID(), m.ID()); err != nil {
+				t.Fatal(err)
+			}
+		}
+	}
+
+	// empty action
+	action := func(ctx context.Context) error {
+		return nil
+	}
+	// 	check always checks out
+	check := func(ctx context.Context, id enode.ID) (bool, error) {
+		return true, nil
+	}
+
+	// run simulation
+	timeout := 30 * time.Second
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+	result := simulations.NewSimulation(net).Run(ctx, &simulations.Step{
+		Action:  action,
+		Trigger: trigger,
+		Expect: &simulations.Expectation{
+			Nodes: []enode.ID{net.Nodes[0].ID()},
+			Check: check,
+		},
+	})
+
+	if result.Error != nil {
+		t.Fatal(result.Error)
+	}
+
+	// check if balance matrix is symmetric
+	if err := bal.symmetric(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// matrix is a matrix of nodes and its balances
+// matrix is in fact a linear array of size n*n,
+// so the balance for any node A with B is at index
+// A*n + B, while the balance of node B with A is at
+// B*n + A
+// (n entries in the array will not be filled -
+//  the balance of a node with itself)
+type matrix struct {
+	n int     //number of nodes
+	m []int64 //array of balances
+}
+
+// create a new matrix
+func newMatrix(n int) *matrix {
+	return &matrix{
+		n: n,
+		m: make([]int64, n*n),
+	}
+}
+
+// called from the testBalance's Add accounting function: register balance change
+func (m *matrix) add(i, j int, v int64) error {
+	// index for the balance of local node i with remote nodde j is
+	// i * number of nodes + remote node
+	mi := i*m.n + j
+	// register that balance
+	m.m[mi] += v
+	return nil
+}
+
+// check that the balances are symmetric:
+// balance of node i with node j is the same as j with i but with inverted signs
+func (m *matrix) symmetric() error {
+	//iterate all nodes
+	for i := 0; i < m.n; i++ {
+		//iterate starting +1
+		for j := i + 1; j < m.n; j++ {
+			log.Debug("bal", "1", i, "2", j, "i,j", m.m[i*m.n+j], "j,i", m.m[j*m.n+i])
+			if m.m[i*m.n+j] != -m.m[j*m.n+i] {
+				return fmt.Errorf("value mismatch. m[%v, %v] = %v; m[%v, %v] = %v", i, j, m.m[i*m.n+j], j, i, m.m[j*m.n+i])
+			}
+		}
+	}
+	return nil
+}
+
+// all the balances
+type balances struct {
+	i int
+	*matrix
+	id2n map[enode.ID]int
+	wg   *sync.WaitGroup
+}
+
+func newBalances(n int) *balances {
+	return &balances{
+		matrix: newMatrix(n),
+		id2n:   make(map[enode.ID]int),
+		wg:     &sync.WaitGroup{},
+	}
+}
+
+// create a new testNode for every node created as part of the service
+func (b *balances) newNode() *testNode {
+	defer func() { b.i++ }()
+	return &testNode{
+		bal:   b,
+		i:     b.i,
+		peers: make([]*testPeer, b.n), //a node will be connected to n-1 peers
+	}
+}
+
+type testNode struct {
+	bal       *balances
+	i         int
+	lock      sync.Mutex
+	peers     []*testPeer
+	peerCount int
+}
+
+// do the accounting for the peer's test protocol
+// testNode implements protocols.Balance
+func (t *testNode) Add(a int64, p *Peer) error {
+	//get the index for the remote peer
+	remote := t.bal.id2n[p.ID()]
+	log.Debug("add", "local", t.i, "remote", remote, "amount", a)
+	return t.bal.add(t.i, remote, a)
+}
+
+//run the p2p protocol
+//for every node, represented by testNode, create a remote testPeer
+func (t *testNode) run(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+	spec := createTestSpec()
+	//create accounting hook
+	spec.Hook = NewAccounting(t, &dummyPrices{})
+
+	//create a peer for this node
+	tp := &testPeer{NewPeer(p, rw, spec), t.i, t.bal.id2n[p.ID()], t.bal.wg}
+	t.lock.Lock()
+	t.peers[t.bal.id2n[p.ID()]] = tp
+	t.peerCount++
+	if t.peerCount == t.bal.n-1 {
+		//when all peer connections are established, start sending messages from this peer
+		go t.send()
+	}
+	t.lock.Unlock()
+	return tp.Run(tp.handle)
+}
+
+// p2p message receive handler function
+func (tp *testPeer) handle(ctx context.Context, msg interface{}) error {
+	tp.wg.Done()
+	log.Debug("receive", "from", tp.remote, "to", tp.local, "type", reflect.TypeOf(msg), "msg", msg)
+	return nil
+}
+
+type testPeer struct {
+	*Peer
+	local, remote int
+	wg            *sync.WaitGroup
+}
+
+func (t *testNode) send() {
+	log.Debug("start sending")
+	for i := 0; i < *msgs; i++ {
+		//determine randomly to which peer to send
+		whom := rand.Intn(t.bal.n - 1)
+		if whom >= t.i {
+			whom++
+		}
+		t.lock.Lock()
+		p := t.peers[whom]
+		t.lock.Unlock()
+
+		//determine a random message from the spec's messages to be sent
+		which := rand.Intn(len(p.spec.Messages))
+		msg := p.spec.Messages[which]
+		switch msg.(type) {
+		case *perBytesMsgReceiverPays:
+			msg = &perBytesMsgReceiverPays{Content: content[:rand.Intn(len(content))]}
+		case *perBytesMsgSenderPays:
+			msg = &perBytesMsgSenderPays{Content: content[:rand.Intn(len(content))]}
+		}
+		log.Debug("send", "from", t.i, "to", whom, "type", reflect.TypeOf(msg), "msg", msg)
+		p.Send(context.TODO(), msg)
+	}
+}
+
+// define the protocol
+func (t *testNode) Protocols() []p2p.Protocol {
+	return []p2p.Protocol{{
+		Length: 100,
+		Run:    t.run,
+	}}
+}
+
+func (t *testNode) APIs() []rpc.API {
+	return nil
+}
+
+func (t *testNode) Start(server *p2p.Server) error {
+	return nil
+}
+
+func (t *testNode) Stop() error {
+	return nil
+}

+ 223 - 0
p2p/protocols/accounting_test.go

@@ -0,0 +1,223 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package protocols
+
+import (
+	"testing"
+
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+//dummy Balance implementation
+type dummyBalance struct {
+	amount int64
+	peer   *Peer
+}
+
+//dummy Prices implementation
+type dummyPrices struct{}
+
+//a dummy message which needs size based accounting
+//sender pays
+type perBytesMsgSenderPays struct {
+	Content string
+}
+
+//a dummy message which needs size based accounting
+//receiver pays
+type perBytesMsgReceiverPays struct {
+	Content string
+}
+
+//a dummy message which is paid for per unit
+//sender pays
+type perUnitMsgSenderPays struct{}
+
+//receiver pays
+type perUnitMsgReceiverPays struct{}
+
+//a dummy message which has zero as its price
+type zeroPriceMsg struct{}
+
+//a dummy message which has no accounting
+type nilPriceMsg struct{}
+
+//return the price for the defined messages
+func (d *dummyPrices) Price(msg interface{}) *Price {
+	switch msg.(type) {
+	//size based message cost, receiver pays
+	case *perBytesMsgReceiverPays:
+		return &Price{
+			PerByte: true,
+			Value:   uint64(100),
+			Payer:   Receiver,
+		}
+	//size based message cost, sender pays
+	case *perBytesMsgSenderPays:
+		return &Price{
+			PerByte: true,
+			Value:   uint64(100),
+			Payer:   Sender,
+		}
+		//unitary cost, receiver pays
+	case *perUnitMsgReceiverPays:
+		return &Price{
+			PerByte: false,
+			Value:   uint64(99),
+			Payer:   Receiver,
+		}
+		//unitary cost, sender pays
+	case *perUnitMsgSenderPays:
+		return &Price{
+			PerByte: false,
+			Value:   uint64(99),
+			Payer:   Sender,
+		}
+	case *zeroPriceMsg:
+		return &Price{
+			PerByte: false,
+			Value:   uint64(0),
+			Payer:   Sender,
+		}
+	case *nilPriceMsg:
+		return nil
+	}
+	return nil
+}
+
+//dummy accounting implementation, only stores values for later check
+func (d *dummyBalance) Add(amount int64, peer *Peer) error {
+	d.amount = amount
+	d.peer = peer
+	return nil
+}
+
+type testCase struct {
+	msg        interface{}
+	size       uint32
+	sendResult int64
+	recvResult int64
+}
+
+//lowest level unit test
+func TestBalance(t *testing.T) {
+	//create instances
+	balance := &dummyBalance{}
+	prices := &dummyPrices{}
+	//create the spec
+	spec := createTestSpec()
+	//create the accounting hook for the spec
+	acc := NewAccounting(balance, prices)
+	//create a peer
+	id := adapters.RandomNodeConfig().ID
+	p := p2p.NewPeer(id, "testPeer", nil)
+	peer := NewPeer(p, &dummyRW{}, spec)
+	//price depends on size, receiver pays
+	msg := &perBytesMsgReceiverPays{Content: "testBalance"}
+	size, _ := rlp.EncodeToBytes(msg)
+
+	testCases := []testCase{
+		{
+			msg,
+			uint32(len(size)),
+			int64(len(size) * 100),
+			int64(len(size) * -100),
+		},
+		{
+			&perBytesMsgSenderPays{Content: "testBalance"},
+			uint32(len(size)),
+			int64(len(size) * -100),
+			int64(len(size) * 100),
+		},
+		{
+			&perUnitMsgSenderPays{},
+			0,
+			int64(-99),
+			int64(99),
+		},
+		{
+			&perUnitMsgReceiverPays{},
+			0,
+			int64(99),
+			int64(-99),
+		},
+		{
+			&zeroPriceMsg{},
+			0,
+			int64(0),
+			int64(0),
+		},
+		{
+			&nilPriceMsg{},
+			0,
+			int64(0),
+			int64(0),
+		},
+	}
+	checkAccountingTestCases(t, testCases, acc, peer, balance, true)
+	checkAccountingTestCases(t, testCases, acc, peer, balance, false)
+}
+
+func checkAccountingTestCases(t *testing.T, cases []testCase, acc *Accounting, peer *Peer, balance *dummyBalance, send bool) {
+	for _, c := range cases {
+		var err error
+		var expectedResult int64
+		//reset balance before every check
+		balance.amount = 0
+		if send {
+			err = acc.Send(peer, c.size, c.msg)
+			expectedResult = c.sendResult
+		} else {
+			err = acc.Receive(peer, c.size, c.msg)
+			expectedResult = c.recvResult
+		}
+
+		checkResults(t, err, balance, peer, expectedResult)
+	}
+}
+
+func checkResults(t *testing.T, err error, balance *dummyBalance, peer *Peer, result int64) {
+	if err != nil {
+		t.Fatal(err)
+	}
+	if balance.peer != peer {
+		t.Fatalf("expected Add to be called with peer %v, got %v", peer, balance.peer)
+	}
+	if balance.amount != result {
+		t.Fatalf("Expected balance to be %d but is %d", result, balance.amount)
+	}
+}
+
+//create a test spec
+func createTestSpec() *Spec {
+	spec := &Spec{
+		Name:       "test",
+		Version:    42,
+		MaxMsgSize: 10 * 1024,
+		Messages: []interface{}{
+			&perBytesMsgReceiverPays{},
+			&perBytesMsgSenderPays{},
+			&perUnitMsgReceiverPays{},
+			&perUnitMsgSenderPays{},
+			&zeroPriceMsg{},
+			&nilPriceMsg{},
+		},
+	}
+	return spec
+}

+ 30 - 0
p2p/protocols/protocol.go

@@ -122,6 +122,16 @@ type WrappedMsg struct {
 	Payload []byte
 }
 
+//For accounting, the design is to allow the Spec to describe which and how its messages are priced
+//To access this functionality, we provide a Hook interface which will call accounting methods
+//NOTE: there could be more such (horizontal) hooks in the future
+type Hook interface {
+	//A hook for sending messages
+	Send(peer *Peer, size uint32, msg interface{}) error
+	//A hook for receiving messages
+	Receive(peer *Peer, size uint32, msg interface{}) error
+}
+
 // Spec is a protocol specification including its name and version as well as
 // the types of messages which are exchanged
 type Spec struct {
@@ -141,6 +151,9 @@ type Spec struct {
 	// each message must have a single unique data type
 	Messages []interface{}
 
+	//hook for accounting (could be extended to multiple hooks in the future)
+	Hook Hook
+
 	initOnce sync.Once
 	codes    map[reflect.Type]uint64
 	types    map[uint64]reflect.Type
@@ -274,6 +287,15 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error {
 		Payload: r,
 	}
 
+	//if the accounting hook is set, call it
+	if p.spec.Hook != nil {
+		err := p.spec.Hook.Send(p, wmsg.Size, msg)
+		if err != nil {
+			p.Drop(err)
+			return err
+		}
+	}
+
 	code, found := p.spec.GetCode(msg)
 	if !found {
 		return errorf(ErrInvalidMsgType, "%v", code)
@@ -336,6 +358,14 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{})
 		return errorf(ErrDecode, "<= %v: %v", msg, err)
 	}
 
+	//if the accounting hook is set, call it
+	if p.spec.Hook != nil {
+		err := p.spec.Hook.Receive(p, wmsg.Size, val)
+		if err != nil {
+			return err
+		}
+	}
+
 	// call the registered handler callbacks
 	// a registered callback take the decoded message as argument as an interface
 	// which the handler is supposed to cast to the appropriate type

+ 202 - 0
p2p/protocols/protocol_test.go

@@ -17,12 +17,15 @@
 package protocols
 
 import (
+	"bytes"
 	"context"
 	"errors"
 	"fmt"
 	"testing"
 	"time"
 
+	"github.com/ethereum/go-ethereum/rlp"
+
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
@@ -185,6 +188,169 @@ func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) {
 	}
 }
 
+type dummyHook struct {
+	peer  *Peer
+	size  uint32
+	msg   interface{}
+	send  bool
+	err   error
+	waitC chan struct{}
+}
+
+type dummyMsg struct {
+	Content string
+}
+
+func (d *dummyHook) Send(peer *Peer, size uint32, msg interface{}) error {
+	d.peer = peer
+	d.size = size
+	d.msg = msg
+	d.send = true
+	return d.err
+}
+
+func (d *dummyHook) Receive(peer *Peer, size uint32, msg interface{}) error {
+	d.peer = peer
+	d.size = size
+	d.msg = msg
+	d.send = false
+	d.waitC <- struct{}{}
+	return d.err
+}
+
+func TestProtocolHook(t *testing.T) {
+	testHook := &dummyHook{
+		waitC: make(chan struct{}, 1),
+	}
+	spec := &Spec{
+		Name:       "test",
+		Version:    42,
+		MaxMsgSize: 10 * 1024,
+		Messages: []interface{}{
+			dummyMsg{},
+		},
+		Hook: testHook,
+	}
+
+	runFunc := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+		peer := NewPeer(p, rw, spec)
+		ctx := context.TODO()
+		err := peer.Send(ctx, &dummyMsg{
+			Content: "handshake"})
+
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		handle := func(ctx context.Context, msg interface{}) error {
+			return nil
+		}
+
+		return peer.Run(handle)
+	}
+
+	conf := adapters.RandomNodeConfig()
+	tester := p2ptest.NewProtocolTester(t, conf.ID, 2, runFunc)
+	err := tester.TestExchanges(p2ptest.Exchange{
+		Expects: []p2ptest.Expect{
+			{
+				Code: 0,
+				Msg:  &dummyMsg{Content: "handshake"},
+				Peer: tester.Nodes[0].ID(),
+			},
+		},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "handshake" {
+		t.Fatal("Expected msg to be set, but it is not")
+	}
+	if !testHook.send {
+		t.Fatal("Expected a send message, but it is not")
+	}
+	if testHook.peer == nil || testHook.peer.ID() != tester.Nodes[0].ID() {
+		t.Fatal("Expected peer ID to be set correctly, but it is not")
+	}
+	if testHook.size != 11 { //11 is the length of the encoded message
+		t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size)
+	}
+
+	err = tester.TestExchanges(p2ptest.Exchange{
+		Triggers: []p2ptest.Trigger{
+			{
+				Code: 0,
+				Msg:  &dummyMsg{Content: "response"},
+				Peer: tester.Nodes[1].ID(),
+			},
+		},
+	})
+
+	<-testHook.waitC
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "response" {
+		t.Fatal("Expected msg to be set, but it is not")
+	}
+	if testHook.send {
+		t.Fatal("Expected a send message, but it is not")
+	}
+	if testHook.peer == nil || testHook.peer.ID() != tester.Nodes[1].ID() {
+		t.Fatal("Expected peer ID to be set correctly, but it is not")
+	}
+	if testHook.size != 10 { //11 is the length of the encoded message
+		t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size)
+	}
+
+	testHook.err = fmt.Errorf("dummy error")
+	err = tester.TestExchanges(p2ptest.Exchange{
+		Triggers: []p2ptest.Trigger{
+			{
+				Code: 0,
+				Msg:  &dummyMsg{Content: "response"},
+				Peer: tester.Nodes[1].ID(),
+			},
+		},
+	})
+
+	<-testHook.waitC
+
+	time.Sleep(100 * time.Millisecond)
+	err = tester.TestDisconnected(&p2ptest.Disconnect{tester.Nodes[1].ID(), testHook.err})
+	if err != nil {
+		t.Fatalf("Expected a specific disconnect error, but got different one: %v", err)
+	}
+
+}
+
+//We need to test that if the hook is not defined, then message infrastructure
+//(send,receive) still works
+func TestNoHook(t *testing.T) {
+	//create a test spec
+	spec := createTestSpec()
+	//a random node
+	id := adapters.RandomNodeConfig().ID
+	//a peer
+	p := p2p.NewPeer(id, "testPeer", nil)
+	rw := &dummyRW{}
+	peer := NewPeer(p, rw, spec)
+	ctx := context.TODO()
+	msg := &perBytesMsgSenderPays{Content: "testBalance"}
+	//send a message
+	err := peer.Send(ctx, msg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	//simulate receiving a message
+	rw.msg = msg
+	peer.handleIncoming(func(ctx context.Context, msg interface{}) error {
+		return nil
+	})
+	//all should just work and not result in any error
+}
+
 func TestProtoHandshakeVersionMismatch(t *testing.T) {
 	runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error()))
 }
@@ -386,3 +552,39 @@ func XTestMultiplePeersDropOther(t *testing.T) {
 		fmt.Errorf("subprotocol error"),
 	)
 }
+
+//dummy implementation of a MsgReadWriter
+//this allows for quick and easy unit tests without
+//having to build up the complete protocol
+type dummyRW struct {
+	msg  interface{}
+	size uint32
+	code uint64
+}
+
+func (d *dummyRW) WriteMsg(msg p2p.Msg) error {
+	return nil
+}
+
+func (d *dummyRW) ReadMsg() (p2p.Msg, error) {
+	enc := bytes.NewReader(d.getDummyMsg())
+	return p2p.Msg{
+		Code:       d.code,
+		Size:       d.size,
+		Payload:    enc,
+		ReceivedAt: time.Now(),
+	}, nil
+}
+
+func (d *dummyRW) getDummyMsg() []byte {
+	r, _ := rlp.EncodeToBytes(d.msg)
+	var b bytes.Buffer
+	wmsg := WrappedMsg{
+		Context: b.Bytes(),
+		Size:    uint32(len(r)),
+		Payload: r,
+	}
+	rr, _ := rlp.EncodeToBytes(wmsg)
+	d.size = uint32(len(rr))
+	return rr
+}