Browse Source

p2p/testing: check for all expectations in TestExchanges

Handle all expectations in ProtocolSession.TestExchanges in any
order that are received.
Janos Guljas 7 years ago
parent
commit
e07603bbc4
2 changed files with 200 additions and 57 deletions
  1. 126 52
      p2p/testing/protocolsession.go
  2. 74 5
      p2p/testing/protocoltester.go

+ 126 - 52
p2p/testing/protocolsession.go

@@ -19,13 +19,17 @@ package testing
 import (
 	"errors"
 	"fmt"
+	"sync"
 	"time"
 
+	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/discover"
 	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
 )
 
+var errTimedOut = errors.New("timed out")
+
 // ProtocolSession is a quasi simulation of a pivot node running
 // a service and a number of dummy peers that can send (trigger) or
 // receive (expect) messages
@@ -46,6 +50,7 @@ type Exchange struct {
 	Label    string
 	Triggers []Trigger
 	Expects  []Expect
+	Timeout  time.Duration
 }
 
 // Trigger is part of the exchange, incoming message for the pivot node
@@ -102,76 +107,145 @@ func (self *ProtocolSession) trigger(trig Trigger) error {
 }
 
 // expect checks an expectation of a message sent out by the pivot node
-func (self *ProtocolSession) expect(exp Expect) error {
-	if exp.Msg == nil {
-		return errors.New("no message to expect")
-	}
-	simNode, ok := self.adapter.GetNode(exp.Peer)
-	if !ok {
-		return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs))
+func (self *ProtocolSession) expect(exps []Expect) error {
+	// construct a map of expectations for each node
+	peerExpects := make(map[discover.NodeID][]Expect)
+	for _, exp := range exps {
+		if exp.Msg == nil {
+			return errors.New("no message to expect")
+		}
+		peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp)
 	}
-	mockNode, ok := simNode.Services()[0].(*mockNode)
-	if !ok {
-		return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer)
+
+	// construct a map of mockNodes for each node
+	mockNodes := make(map[discover.NodeID]*mockNode)
+	for nodeID := range peerExpects {
+		simNode, ok := self.adapter.GetNode(nodeID)
+		if !ok {
+			return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs))
+		}
+		mockNode, ok := simNode.Services()[0].(*mockNode)
+		if !ok {
+			return fmt.Errorf("trigger: peer %v is not a mock", nodeID)
+		}
+		mockNodes[nodeID] = mockNode
 	}
 
+	// done chanell cancels all created goroutines when function returns
+	done := make(chan struct{})
+	defer close(done)
+	// errc catches the first error from
 	errc := make(chan error)
+
+	wg := &sync.WaitGroup{}
+	wg.Add(len(mockNodes))
+	for nodeID, mockNode := range mockNodes {
+		nodeID := nodeID
+		mockNode := mockNode
+		go func() {
+			defer wg.Done()
+
+			// Sum all Expect timeouts to give the maximum
+			// time for all expectations to finish.
+			// mockNode.Expect checks all received messages against
+			// a list of expected messages and timeout for each
+			// of them can not be checked separately.
+			var t time.Duration
+			for _, exp := range peerExpects[nodeID] {
+				if exp.Timeout == time.Duration(0) {
+					t += 2000 * time.Millisecond
+				} else {
+					t += exp.Timeout
+				}
+			}
+			alarm := time.NewTimer(t)
+			defer alarm.Stop()
+
+			// expectErrc is used to check if error returned
+			// from mockNode.Expect is not nil and to send it to
+			// errc only in that case.
+			// done channel will be closed when function
+			expectErrc := make(chan error)
+			go func() {
+				select {
+				case expectErrc <- mockNode.Expect(peerExpects[nodeID]...):
+				case <-done:
+				case <-alarm.C:
+				}
+			}()
+
+			select {
+			case err := <-expectErrc:
+				if err != nil {
+					select {
+					case errc <- err:
+					case <-done:
+					case <-alarm.C:
+						errc <- errTimedOut
+					}
+				}
+			case <-done:
+			case <-alarm.C:
+				errc <- errTimedOut
+			}
+
+		}()
+	}
+
 	go func() {
-		errc <- mockNode.Expect(&exp)
+		wg.Wait()
+		// close errc when all goroutines finish to return nill err from errc
+		close(errc)
 	}()
 
-	t := exp.Timeout
-	if t == time.Duration(0) {
-		t = 2000 * time.Millisecond
-	}
-	select {
-	case err := <-errc:
-		return err
-	case <-time.After(t):
-		return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer)
-	}
+	return <-errc
 }
 
 // TestExchanges tests a series of exchanges against the session
 func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error {
-	// launch all triggers of this exchanges
+	for i, e := range exchanges {
+		if err := self.testExchange(e); err != nil {
+			return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err)
+		}
+		log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label))
+	}
+	return nil
+}
+
+// testExchange tests a single Exchange.
+// Default timeout value is 2 seconds.
+func (self *ProtocolSession) testExchange(e Exchange) error {
+	errc := make(chan error)
+	done := make(chan struct{})
+	defer close(done)
 
-	for _, e := range exchanges {
-		errc := make(chan error, len(e.Triggers)+len(e.Expects))
+	go func() {
 		for _, trig := range e.Triggers {
-			errc <- self.trigger(trig)
+			err := self.trigger(trig)
+			if err != nil {
+				errc <- err
+				return
+			}
 		}
 
-		// each expectation is spawned in separate go-routine
-		// expectations of an exchange are conjunctive but unordered, i.e.,
-		// only all of them arriving constitutes a pass
-		// each expectation is meant to be for a different peer, otherwise they are expected to panic
-		// testing of an exchange blocks until all expectations are decided
-		// an expectation is decided if
-		//  expected message arrives OR
-		// an unexpected message arrives (panic)
-		// times out on their individual timeout
-		for _, ex := range e.Expects {
-			// expect msg spawned to separate go routine
-			go func(exp Expect) {
-				errc <- self.expect(exp)
-			}(ex)
+		select {
+		case errc <- self.expect(e.Expects):
+		case <-done:
 		}
+	}()
 
-		// time out globally or finish when all expectations satisfied
-		timeout := time.After(5 * time.Second)
-		for i := 0; i < len(e.Triggers)+len(e.Expects); i++ {
-			select {
-			case err := <-errc:
-				if err != nil {
-					return fmt.Errorf("exchange failed with: %v", err)
-				}
-			case <-timeout:
-				return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label)
-			}
-		}
+	// time out globally or finish when all expectations satisfied
+	t := e.Timeout
+	if t == 0 {
+		t = 2000 * time.Millisecond
+	}
+	alarm := time.NewTimer(t)
+	select {
+	case err := <-errc:
+		return err
+	case <-alarm.C:
+		return errTimedOut
 	}
-	return nil
 }
 
 // TestDisconnected tests the disconnections given as arguments

+ 74 - 5
p2p/testing/protocoltester.go

@@ -24,7 +24,11 @@ that can be used to send and receive messages
 package testing
 
 import (
+	"bytes"
 	"fmt"
+	"io"
+	"io/ioutil"
+	"strings"
 	"sync"
 	"testing"
 
@@ -34,6 +38,7 @@ import (
 	"github.com/ethereum/go-ethereum/p2p/discover"
 	"github.com/ethereum/go-ethereum/p2p/simulations"
 	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
+	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/rpc"
 )
 
@@ -152,7 +157,7 @@ type mockNode struct {
 	testNode
 
 	trigger  chan *Trigger
-	expect   chan *Expect
+	expect   chan []Expect
 	err      chan error
 	stop     chan struct{}
 	stopOnce sync.Once
@@ -161,7 +166,7 @@ type mockNode struct {
 func newMockNode() *mockNode {
 	mock := &mockNode{
 		trigger: make(chan *Trigger),
-		expect:  make(chan *Expect),
+		expect:  make(chan []Expect),
 		err:     make(chan error),
 		stop:    make(chan struct{}),
 	}
@@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
 		select {
 		case trig := <-m.trigger:
 			m.err <- p2p.Send(rw, trig.Code, trig.Msg)
-		case exp := <-m.expect:
-			m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg)
+		case exps := <-m.expect:
+			m.err <- expectMsgs(rw, exps)
 		case <-m.stop:
 			return nil
 		}
@@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error {
 	return <-m.err
 }
 
-func (m *mockNode) Expect(exp *Expect) error {
+func (m *mockNode) Expect(exp ...Expect) error {
 	m.expect <- exp
 	return <-m.err
 }
@@ -198,3 +203,67 @@ func (m *mockNode) Stop() error {
 	m.stopOnce.Do(func() { close(m.stop) })
 	return nil
 }
+
+func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
+	matched := make([]bool, len(exps))
+	for {
+		msg, err := rw.ReadMsg()
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			return err
+		}
+		actualContent, err := ioutil.ReadAll(msg.Payload)
+		if err != nil {
+			return err
+		}
+		var found bool
+		for i, exp := range exps {
+			if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
+				if matched[i] {
+					return fmt.Errorf("message #%d received two times", i)
+				}
+				matched[i] = true
+				found = true
+				break
+			}
+		}
+		if !found {
+			expected := make([]string, 0)
+			for i, exp := range exps {
+				if matched[i] {
+					continue
+				}
+				expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
+			}
+			return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
+		}
+		done := true
+		for _, m := range matched {
+			if !m {
+				done = false
+				break
+			}
+		}
+		if done {
+			return nil
+		}
+	}
+	for i, m := range matched {
+		if !m {
+			return fmt.Errorf("expected message #%d not received", i)
+		}
+	}
+	return nil
+}
+
+// mustEncodeMsg uses rlp to encode a message.
+// In case of error it panics.
+func mustEncodeMsg(msg interface{}) []byte {
+	contentEnc, err := rlp.EncodeToBytes(msg)
+	if err != nil {
+		panic("content encode error: " + err.Error())
+	}
+	return contentEnc
+}