
p2p: new dial scheduler (#20592)

* p2p: new dial scheduler

This change replaces the peer-to-peer dial scheduler with a new
implementation that improves on the previous one in two key aspects:

- The time between discovery of a node and dialing that node is
  significantly lower in the new version. The old dialState kept
  a buffer of nodes and launched a task to refill it whenever the buffer
  became empty. This worked well with the discovery interface we used to
  have, but doesn't really work with the new iterator-based discovery
  API.

- Selection of static dial candidates (created by Server.AddPeer or
  through static-nodes.json) performs much better for large numbers of
  static peers. Connections to static nodes are now limited like dynamic
  dials and can no longer exceed MaxPeers or the dial ratio (see the
  sketch after this list).
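
As a rough illustration of both points, here is a minimal sketch of the
iterator-driven pattern the new scheduler follows: candidates are read
straight from an enode.Iterator and dialed only while free slots remain.
This is illustrative code only, not the dial.go implementation; the
consume helper, the slot count, and the dial callback are invented for
the example, while enode.Iterator, enode.IterNodes and enode.SignNull
are real parts of the p2p/enode package.

    package main

    import (
    	"fmt"

    	"github.com/ethereum/go-ethereum/p2p/enode"
    	"github.com/ethereum/go-ethereum/p2p/enr"
    )

    // consume reads candidates straight from the discovery iterator and
    // launches a dial only while free slots remain, instead of buffering
    // lookup results the way the old dialState did.
    func consume(it enode.Iterator, slots int, dial func(*enode.Node)) {
    	for slots > 0 && it.Next() {
    		dial(it.Node())
    		slots--
    	}
    }

    func main() {
    	var r1, r2 enr.Record
    	nodes := []*enode.Node{
    		enode.SignNull(&r1, enode.ID{1}),
    		enode.SignNull(&r2, enode.ID{2}),
    	}
    	// One free slot: only the first discovered node is dialed.
    	consume(enode.IterNodes(nodes), 1, func(n *enode.Node) {
    		fmt.Println("dialing", n.ID())
    	})
    }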

* p2p/simulations/adapters: adapt to new NodeDialer interface

* p2p: re-add check for self in checkDial

* p2p: remove peersetCh

* p2p: allow static dials when discovery is disabled

* p2p: add test for dialScheduler.removeStatic

* p2p: remove blank line

* p2p: fix documentation of maxDialPeers

* p2p: change "ok" to "added" in static node log

* p2p: improve dialTask docs

Also increase log level for "Can't resolve node"

* p2p: ensure dial resolver is truly nil without discovery

* p2p: add "looking for peers" log message

* p2p: clean up Server.run comments

* p2p: fix maxDialedConns for maxpeers < dialRatio

Always allocate at least one dial slot unless dialing is disabled using
NoDial or MaxPeers == 0. Most importantly, this fixes MaxPeers == 1 to
dedicate the sole slot to dialing instead of listening.
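
To make the fix concrete, a small self-contained sketch of the slot
computation described above. The function name and standalone form are
illustrative; the real logic lives in p2p/server.go, which is not part
of this excerpt.

    package main

    import "fmt"

    // maxDialedSlots sketches the fixed behaviour: at least one dial slot
    // unless dialing is disabled via noDial or maxPeers == 0.
    func maxDialedSlots(maxPeers, dialRatio int, noDial bool) int {
    	if noDial || maxPeers == 0 {
    		return 0 // dialing disabled
    	}
    	if dialRatio == 0 {
    		dialRatio = 3 // default dial ratio in the p2p package
    	}
    	slots := maxPeers / dialRatio
    	if slots == 0 {
    		slots = 1 // the fix: MaxPeers == 1 now yields a dial slot
    	}
    	return slots
    }

    func main() {
    	fmt.Println(maxDialedSlots(1, 0, false))  // 1 (previously 0)
    	fmt.Println(maxDialedSlots(25, 0, false)) // 8
    	fmt.Println(maxDialedSlots(0, 0, false))  // 0, dialing disabled
    }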

* p2p: fix RemovePeer to disconnect the peer again

Also make RemovePeer synchronous and add a test.
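
A hedged usage sketch of what the synchronous behaviour means for
callers. The package and function names below are invented for the
example; only Server.AddPeer and Server.RemovePeer are real API.

    package peering

    import (
    	"github.com/ethereum/go-ethereum/p2p"
    	"github.com/ethereum/go-ethereum/p2p/enode"
    )

    // ReplaceStaticPeer relies on the behaviour described above: RemovePeer
    // now returns only after the removed peer has been disconnected, so the
    // code following the call observes the old connection as already gone.
    func ReplaceStaticPeer(srv *p2p.Server, oldNode, newNode *enode.Node) {
    	srv.RemovePeer(oldNode) // synchronous since this change
    	srv.AddPeer(newNode)    // schedule a static dial to the replacement node
    }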

* p2p: remove "Connection set up" log message

* p2p: clean up connection logging

We previously logged outgoing connection failures up to three times.

- in SetupConn() as "Setting up connection failed addr=..."
- in setupConn() with an error-specific message and "id=... addr=..."
- in dial() as "Dial error task=..."

This commit ensures a single log message is emitted per failure and adds
"id=... addr=... conn=..." everywhere (id= omitted when the ID isn't
known yet).

Also avoid printing a log message when a static dial fails but can't be
resolved because discv4 is disabled. The light client hit this case all
the time, increasing the message count to four lines per failed
connection.

* p2p: document that RemovePeer blocks
Felix Lange, 5 years ago
commit 90caa2cabb
8 changed files with 1,244 additions and 1,048 deletions
  1. p2p/dial.go (+395, -196)
  2. p2p/dial_test.go (+583, -501)
  3. p2p/peer_test.go (+42, -2)
  4. p2p/server.go (+144, -176)
  5. p2p/server_test.go (+59, -159)
  6. p2p/simulations/adapters/inproc.go (+2, -1)
  7. p2p/util.go (+12, -8)
  8. p2p/util_test.go (+7, -5)

p2p/dial.go (+395, -196)

@@ -17,11 +17,17 @@
 package p2p
 
 import (
+	"context"
+	crand "crypto/rand"
+	"encoding/binary"
 	"errors"
 	"fmt"
+	mrand "math/rand"
 	"net"
+	"sync"
 	"time"
 
+	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/netutil"
@@ -33,8 +39,9 @@ const (
 	// private networks.
 	dialHistoryExpiration = inboundThrottleTime + 5*time.Second
 
-	// If no peers are found for this amount of time, the initial bootnodes are dialed.
-	fallbackInterval = 20 * time.Second
+	// Config for the "Looking for peers" message.
+	dialStatsLogInterval = 10 * time.Second // printed at most this often
+	dialStatsPeerLimit   = 3                // but not if more than this many dialed peers
 
 	// Endpoint resolution is throttled with bounded backoff.
 	initialResolveDelay = 60 * time.Second
@@ -42,219 +49,443 @@ const (
 )
 
 // NodeDialer is used to connect to nodes in the network, typically by using
-// an underlying net.Dialer but also using net.Pipe in tests
+// an underlying net.Dialer but also using net.Pipe in tests.
 type NodeDialer interface {
-	Dial(*enode.Node) (net.Conn, error)
+	Dial(context.Context, *enode.Node) (net.Conn, error)
 }
 
 type nodeResolver interface {
 	Resolve(*enode.Node) *enode.Node
 }
 
-// TCPDialer implements the NodeDialer interface by using a net.Dialer to
-// create TCP connections to nodes in the network
-type TCPDialer struct {
-	*net.Dialer
+// tcpDialer implements NodeDialer using real TCP connections.
+type tcpDialer struct {
+	d *net.Dialer
 }
 
-// Dial creates a TCP connection to the node
-func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) {
-	addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()}
-	return t.Dialer.Dial("tcp", addr.String())
+func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
+	return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String())
 }
 
-// dialstate schedules dials and discovery lookups.
-// It gets a chance to compute new tasks on every iteration
-// of the main loop in Server.run.
-type dialstate struct {
-	maxDynDials int
-	netrestrict *netutil.Netlist
-	self        enode.ID
-	bootnodes   []*enode.Node // default dials when there are no peers
-	log         log.Logger
+func nodeAddr(n *enode.Node) net.Addr {
+	return &net.TCPAddr{IP: n.IP(), Port: n.TCP()}
+}
+
+// checkDial errors:
+var (
+	errSelf             = errors.New("is self")
+	errAlreadyDialing   = errors.New("already dialing")
+	errAlreadyConnected = errors.New("already connected")
+	errRecentlyDialed   = errors.New("recently dialed")
+	errNotWhitelisted   = errors.New("not contained in netrestrict whitelist")
+)
+
+// dialer creates outbound connections and submits them into Server.
+// Two types of peer connections can be created:
+//
+//  - static dials are pre-configured connections. The dialer attempts to
+//    keep these nodes connected at all times.
+//
+//  - dynamic dials are created from node discovery results. The dialer
+//    continuously reads candidate nodes from its input iterator and attempts
+//    to create peer connections to nodes arriving through the iterator.
+//
+type dialScheduler struct {
+	dialConfig
+	setupFunc   dialSetupFunc
+	wg          sync.WaitGroup
+	cancel      context.CancelFunc
+	ctx         context.Context
+	nodesIn     chan *enode.Node
+	doneCh      chan *dialTask
+	addStaticCh chan *enode.Node
+	remStaticCh chan *enode.Node
+	addPeerCh   chan *conn
+	remPeerCh   chan *conn
+
+	// Everything below here belongs to loop and
+	// should only be accessed by code on the loop goroutine.
+	dialing   map[enode.ID]*dialTask // active tasks
+	peers     map[enode.ID]connFlag  // all connected peers
+	dialPeers int                    // current number of dialed peers
+
+	// The static map tracks all static dial tasks. The subset of usable static dial tasks
+	// (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers
+	// launching random static tasks from the pool over launching dynamic dials from the
+	// iterator.
+	static     map[enode.ID]*dialTask
+	staticPool []*dialTask
+
+	// The dial history keeps recently dialed nodes. Members of history are not dialed.
+	history          expHeap
+	historyTimer     mclock.Timer
+	historyTimerTime mclock.AbsTime
+
+	// for logStats
+	lastStatsLog     mclock.AbsTime
+	doneSinceLastLog int
+}
 
-	start         time.Time // time when the dialer was first used
-	lookupRunning bool
-	dialing       map[enode.ID]connFlag
-	lookupBuf     []*enode.Node // current discovery lookup results
-	static        map[enode.ID]*dialTask
-	hist          expHeap
+type dialSetupFunc func(net.Conn, connFlag, *enode.Node) error
+
+type dialConfig struct {
+	self           enode.ID         // our own ID
+	maxDialPeers   int              // maximum number of dialed peers
+	maxActiveDials int              // maximum number of active dials
+	netRestrict    *netutil.Netlist // IP whitelist, disabled if nil
+	resolver       nodeResolver
+	dialer         NodeDialer
+	log            log.Logger
+	clock          mclock.Clock
+	rand           *mrand.Rand
 }
 
-type task interface {
-	Do(*Server)
+func (cfg dialConfig) withDefaults() dialConfig {
+	if cfg.maxActiveDials == 0 {
+		cfg.maxActiveDials = defaultMaxPendingPeers
+	}
+	if cfg.log == nil {
+		cfg.log = log.Root()
+	}
+	if cfg.clock == nil {
+		cfg.clock = mclock.System{}
+	}
+	if cfg.rand == nil {
+		seedb := make([]byte, 8)
+		crand.Read(seedb)
+		seed := int64(binary.BigEndian.Uint64(seedb))
+		cfg.rand = mrand.New(mrand.NewSource(seed))
+	}
+	return cfg
 }
 
-func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate {
-	s := &dialstate{
-		maxDynDials: maxdyn,
-		self:        self,
-		netrestrict: cfg.NetRestrict,
-		log:         cfg.Logger,
+func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler {
+	d := &dialScheduler{
+		dialConfig:  config.withDefaults(),
+		setupFunc:   setupFunc,
+		dialing:     make(map[enode.ID]*dialTask),
 		static:      make(map[enode.ID]*dialTask),
-		dialing:     make(map[enode.ID]connFlag),
-		bootnodes:   make([]*enode.Node, len(cfg.BootstrapNodes)),
+		peers:       make(map[enode.ID]connFlag),
+		doneCh:      make(chan *dialTask),
+		nodesIn:     make(chan *enode.Node),
+		addStaticCh: make(chan *enode.Node),
+		remStaticCh: make(chan *enode.Node),
+		addPeerCh:   make(chan *conn),
+		remPeerCh:   make(chan *conn),
 	}
-	copy(s.bootnodes, cfg.BootstrapNodes)
-	if s.log == nil {
-		s.log = log.Root()
+	d.lastStatsLog = d.clock.Now()
+	d.ctx, d.cancel = context.WithCancel(context.Background())
+	d.wg.Add(2)
+	go d.readNodes(it)
+	go d.loop(it)
+	return d
+}
+
+// stop shuts down the dialer, canceling all current dial tasks.
+func (d *dialScheduler) stop() {
+	d.cancel()
+	d.wg.Wait()
+}
+
+// addStatic adds a static dial candidate.
+func (d *dialScheduler) addStatic(n *enode.Node) {
+	select {
+	case d.addStaticCh <- n:
+	case <-d.ctx.Done():
 	}
-	for _, n := range cfg.StaticNodes {
-		s.addStatic(n)
+}
+
+// removeStatic removes a static dial candidate.
+func (d *dialScheduler) removeStatic(n *enode.Node) {
+	select {
+	case d.remStaticCh <- n:
+	case <-d.ctx.Done():
 	}
-	return s
 }
 
-func (s *dialstate) addStatic(n *enode.Node) {
-	// This overwrites the task instead of updating an existing
-	// entry, giving users the opportunity to force a resolve operation.
-	s.static[n.ID()] = &dialTask{flags: staticDialedConn, dest: n}
+// peerAdded updates the peer set.
+func (d *dialScheduler) peerAdded(c *conn) {
+	select {
+	case d.addPeerCh <- c:
+	case <-d.ctx.Done():
+	}
 }
 
-func (s *dialstate) removeStatic(n *enode.Node) {
-	// This removes a task so future attempts to connect will not be made.
-	delete(s.static, n.ID())
+// peerRemoved updates the peer set.
+func (d *dialScheduler) peerRemoved(c *conn) {
+	select {
+	case d.remPeerCh <- c:
+	case <-d.ctx.Done():
+	}
 }
 
-func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
-	var newtasks []task
-	addDial := func(flag connFlag, n *enode.Node) bool {
-		if err := s.checkDial(n, peers); err != nil {
-			s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
-			return false
-		}
-		s.dialing[n.ID()] = flag
-		newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
-		return true
-	}
-
-	if s.start.IsZero() {
-		s.start = now
-	}
-	s.hist.expire(now)
-
-	// Create dials for static nodes if they are not connected.
-	for id, t := range s.static {
-		err := s.checkDial(t.dest, peers)
-		switch err {
-		case errNotWhitelisted, errSelf:
-			s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
-			delete(s.static, t.dest.ID())
-		case nil:
-			s.dialing[id] = t.flags
-			newtasks = append(newtasks, t)
+// loop is the main loop of the dialer.
+func (d *dialScheduler) loop(it enode.Iterator) {
+	var (
+		nodesCh    chan *enode.Node
+		historyExp = make(chan struct{}, 1)
+	)
+
+loop:
+	for {
+		// Launch new dials if slots are available.
+		slots := d.freeDialSlots()
+		slots -= d.startStaticDials(slots)
+		if slots > 0 {
+			nodesCh = d.nodesIn
+		} else {
+			nodesCh = nil
 		}
-	}
+		d.rearmHistoryTimer(historyExp)
+		d.logStats()
+
+		select {
+		case node := <-nodesCh:
+			if err := d.checkDial(node); err != nil {
+				d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err)
+			} else {
+				d.startDial(newDialTask(node, dynDialedConn))
+			}
+
+		case task := <-d.doneCh:
+			id := task.dest.ID()
+			delete(d.dialing, id)
+			d.updateStaticPool(id)
+			d.doneSinceLastLog++
+
+		case c := <-d.addPeerCh:
+			if c.is(dynDialedConn) || c.is(staticDialedConn) {
+				d.dialPeers++
+			}
+			id := c.node.ID()
+			d.peers[id] = c.flags
+			// Remove from static pool because the node is now connected.
+			task := d.static[id]
+			if task != nil && task.staticPoolIndex >= 0 {
+				d.removeFromStaticPool(task.staticPoolIndex)
+			}
+			// TODO: cancel dials to connected peers
+
+		case c := <-d.remPeerCh:
+			if c.is(dynDialedConn) || c.is(staticDialedConn) {
+				d.dialPeers--
+			}
+			delete(d.peers, c.node.ID())
+			d.updateStaticPool(c.node.ID())
+
+		case node := <-d.addStaticCh:
+			id := node.ID()
+			_, exists := d.static[id]
+			d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists)
+			if exists {
+				continue loop
+			}
+			task := newDialTask(node, staticDialedConn)
+			d.static[id] = task
+			if d.checkDial(node) == nil {
+				d.addToStaticPool(task)
+			}
+
+		case node := <-d.remStaticCh:
+			id := node.ID()
+			task := d.static[id]
+			d.log.Trace("Removing static node", "id", id, "ok", task != nil)
+			if task != nil {
+				delete(d.static, id)
+				if task.staticPoolIndex >= 0 {
+					d.removeFromStaticPool(task.staticPoolIndex)
+				}
+			}
 
-	// Compute number of dynamic dials needed.
-	needDynDials := s.maxDynDials
-	for _, p := range peers {
-		if p.rw.is(dynDialedConn) {
-			needDynDials--
+		case <-historyExp:
+			d.expireHistory()
+
+		case <-d.ctx.Done():
+			it.Close()
+			break loop
 		}
 	}
-	for _, flag := range s.dialing {
-		if flag&dynDialedConn != 0 {
-			needDynDials--
-		}
+
+	d.stopHistoryTimer(historyExp)
+	for range d.dialing {
+		<-d.doneCh
 	}
+	d.wg.Done()
+}
+
+// readNodes runs in its own goroutine and delivers nodes from
+// the input iterator to the nodesIn channel.
+func (d *dialScheduler) readNodes(it enode.Iterator) {
+	defer d.wg.Done()
 
-	// If we don't have any peers whatsoever, try to dial a random bootnode. This
-	// scenario is useful for the testnet (and private networks) where the discovery
-	// table might be full of mostly bad peers, making it hard to find good ones.
-	if len(peers) == 0 && len(s.bootnodes) > 0 && needDynDials > 0 && now.Sub(s.start) > fallbackInterval {
-		bootnode := s.bootnodes[0]
-		s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...)
-		s.bootnodes = append(s.bootnodes, bootnode)
-		if addDial(dynDialedConn, bootnode) {
-			needDynDials--
+	for it.Next() {
+		select {
+		case d.nodesIn <- it.Node():
+		case <-d.ctx.Done():
 		}
 	}
+}
 
-	// Create dynamic dials from discovery results.
-	i := 0
-	for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
-		if addDial(dynDialedConn, s.lookupBuf[i]) {
-			needDynDials--
-		}
+// logStats prints dialer statistics to the log. The message is suppressed when enough
+// peers are connected because users should only see it while their client is starting up
+// or comes back online.
+func (d *dialScheduler) logStats() {
+	now := d.clock.Now()
+	if d.lastStatsLog.Add(dialStatsLogInterval) > now {
+		return
+	}
+	if d.dialPeers < dialStatsPeerLimit && d.dialPeers < d.maxDialPeers {
+		d.log.Info("Looking for peers", "peercount", len(d.peers), "tried", d.doneSinceLastLog, "static", len(d.static))
 	}
-	s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
+	d.doneSinceLastLog = 0
+	d.lastStatsLog = now
+}
 
-	// Launch a discovery lookup if more candidates are needed.
-	if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
-		s.lookupRunning = true
-		newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)})
+// rearmHistoryTimer configures d.historyTimer to fire when the
+// next item in d.history expires.
+func (d *dialScheduler) rearmHistoryTimer(ch chan struct{}) {
+	if len(d.history) == 0 || d.historyTimerTime == d.history.nextExpiry() {
+		return
 	}
+	d.stopHistoryTimer(ch)
+	d.historyTimerTime = d.history.nextExpiry()
+	timeout := time.Duration(d.historyTimerTime - d.clock.Now())
+	d.historyTimer = d.clock.AfterFunc(timeout, func() { ch <- struct{}{} })
+}
 
-	// Launch a timer to wait for the next node to expire if all
-	// candidates have been tried and no task is currently active.
-	// This should prevent cases where the dialer logic is not ticked
-	// because there are no pending events.
-	if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
-		t := &waitExpireTask{s.hist.nextExpiry().Sub(now)}
-		newtasks = append(newtasks, t)
+// stopHistoryTimer stops the timer and drains the channel it sends on.
+func (d *dialScheduler) stopHistoryTimer(ch chan struct{}) {
+	if d.historyTimer != nil && !d.historyTimer.Stop() {
+		<-ch
 	}
-	return newtasks
 }
 
-var (
-	errSelf             = errors.New("is self")
-	errAlreadyDialing   = errors.New("already dialing")
-	errAlreadyConnected = errors.New("already connected")
-	errRecentlyDialed   = errors.New("recently dialed")
-	errNotWhitelisted   = errors.New("not contained in netrestrict whitelist")
-)
+// expireHistory removes expired items from d.history.
+func (d *dialScheduler) expireHistory() {
+	d.historyTimer.Stop()
+	d.historyTimer = nil
+	d.historyTimerTime = 0
+	d.history.expire(d.clock.Now(), func(hkey string) {
+		var id enode.ID
+		copy(id[:], hkey)
+		d.updateStaticPool(id)
+	})
+}
+
+// freeDialSlots returns the number of free dial slots. The result can be negative
+// when peers are connected while their task is still running.
+func (d *dialScheduler) freeDialSlots() int {
+	slots := (d.maxDialPeers - d.dialPeers) * 2
+	if slots > d.maxActiveDials {
+		slots = d.maxActiveDials
+	}
+	free := slots - len(d.dialing)
+	return free
+}
 
-func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
-	_, dialing := s.dialing[n.ID()]
-	switch {
-	case dialing:
+// checkDial returns an error if node n should not be dialed.
+func (d *dialScheduler) checkDial(n *enode.Node) error {
+	if n.ID() == d.self {
+		return errSelf
+	}
+	if _, ok := d.dialing[n.ID()]; ok {
 		return errAlreadyDialing
-	case peers[n.ID()] != nil:
+	}
+	if _, ok := d.peers[n.ID()]; ok {
 		return errAlreadyConnected
-	case n.ID() == s.self:
-		return errSelf
-	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
+	}
+	if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) {
 		return errNotWhitelisted
-	case s.hist.contains(string(n.ID().Bytes())):
+	}
+	if d.history.contains(string(n.ID().Bytes())) {
 		return errRecentlyDialed
 	}
 	return nil
 }
 
-func (s *dialstate) taskDone(t task, now time.Time) {
-	switch t := t.(type) {
-	case *dialTask:
-		s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration))
-		delete(s.dialing, t.dest.ID())
-	case *discoverTask:
-		s.lookupRunning = false
-		s.lookupBuf = append(s.lookupBuf, t.results...)
+// startStaticDials starts n static dial tasks.
+func (d *dialScheduler) startStaticDials(n int) (started int) {
+	for started = 0; started < n && len(d.staticPool) > 0; started++ {
+		idx := d.rand.Intn(len(d.staticPool))
+		task := d.staticPool[idx]
+		d.startDial(task)
+		d.removeFromStaticPool(idx)
+	}
+	return started
+}
+
+// updateStaticPool attempts to move the given static dial back into staticPool.
+func (d *dialScheduler) updateStaticPool(id enode.ID) {
+	task, ok := d.static[id]
+	if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil {
+		d.addToStaticPool(task)
+	}
+}
+
+func (d *dialScheduler) addToStaticPool(task *dialTask) {
+	if task.staticPoolIndex >= 0 {
+		panic("attempt to add task to staticPool twice")
 	}
+	d.staticPool = append(d.staticPool, task)
+	task.staticPoolIndex = len(d.staticPool) - 1
 }
 
-// A dialTask is generated for each node that is dialed. Its
-// fields cannot be accessed while the task is running.
+// removeFromStaticPool removes the task at idx from staticPool. It does that by moving the
+// current last element of the pool to idx and then shortening the pool by one.
+func (d *dialScheduler) removeFromStaticPool(idx int) {
+	task := d.staticPool[idx]
+	end := len(d.staticPool) - 1
+	d.staticPool[idx] = d.staticPool[end]
+	d.staticPool[idx].staticPoolIndex = idx
+	d.staticPool[end] = nil
+	d.staticPool = d.staticPool[:end]
+	task.staticPoolIndex = -1
+}
+
+// startDial runs the given dial task in a separate goroutine.
+func (d *dialScheduler) startDial(task *dialTask) {
+	d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags)
+	hkey := string(task.dest.ID().Bytes())
+	d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
+	d.dialing[task.dest.ID()] = task
+	go func() {
+		task.run(d)
+		d.doneCh <- task
+	}()
+}
+
+// A dialTask is generated for each node that is dialed.
 type dialTask struct {
-	flags        connFlag
+	staticPoolIndex int
+	flags           connFlag
+	// These fields are private to the task and should not be
+	// accessed by dialScheduler while the task is running.
 	dest         *enode.Node
-	lastResolved time.Time
+	lastResolved mclock.AbsTime
 	resolveDelay time.Duration
 }
 
-func (t *dialTask) Do(srv *Server) {
+func newDialTask(dest *enode.Node, flags connFlag) *dialTask {
+	return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1}
+}
+
+type dialError struct {
+	error
+}
+
+func (t *dialTask) run(d *dialScheduler) {
 	if t.dest.Incomplete() {
-		if !t.resolve(srv) {
+		if !t.resolve(d) {
 			return
 		}
 	}
-	err := t.dial(srv, t.dest)
+
+	err := t.dial(d, t.dest)
 	if err != nil {
-		srv.log.Trace("Dial error", "task", t, "err", err)
 		// Try resolving the ID of static nodes if dialing failed.
 		if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
-			if t.resolve(srv) {
-				t.dial(srv, t.dest)
+			if t.resolve(d) {
+				t.dial(d, t.dest)
 			}
 		}
 	}
@@ -266,46 +497,42 @@ func (t *dialTask) Do(srv *Server) {
 // Resolve operations are throttled with backoff to avoid flooding the
 // discovery network with useless queries for nodes that don't exist.
 // The backoff delay resets when the node is found.
-func (t *dialTask) resolve(srv *Server) bool {
-	if srv.staticNodeResolver == nil {
-		srv.log.Debug("Can't resolve node", "id", t.dest.ID(), "err", "discovery is disabled")
+func (t *dialTask) resolve(d *dialScheduler) bool {
+	if d.resolver == nil {
 		return false
 	}
 	if t.resolveDelay == 0 {
 		t.resolveDelay = initialResolveDelay
 	}
-	if time.Since(t.lastResolved) < t.resolveDelay {
+	if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay {
 		return false
 	}
-	resolved := srv.staticNodeResolver.Resolve(t.dest)
-	t.lastResolved = time.Now()
+	resolved := d.resolver.Resolve(t.dest)
+	t.lastResolved = d.clock.Now()
 	if resolved == nil {
 		t.resolveDelay *= 2
 		if t.resolveDelay > maxResolveDelay {
 			t.resolveDelay = maxResolveDelay
 		}
-		srv.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay)
+		d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay)
 		return false
 	}
 	// The node was found.
 	t.resolveDelay = initialResolveDelay
 	t.dest = resolved
-	srv.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
+	d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
 	return true
 }
 
-type dialError struct {
-	error
-}
-
 // dial performs the actual connection attempt.
-func (t *dialTask) dial(srv *Server, dest *enode.Node) error {
-	fd, err := srv.Dialer.Dial(dest)
+func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
+	fd, err := d.dialer.Dial(d.ctx, t.dest)
 	if err != nil {
+		d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err))
 		return &dialError{err}
 	}
 	mfd := newMeteredConn(fd, false, &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()})
-	return srv.SetupConn(mfd, t.flags, dest)
+	return d.setupFunc(mfd, t.flags, dest)
 }
 
 func (t *dialTask) String() string {
@@ -313,37 +540,9 @@ func (t *dialTask) String() string {
 	return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP())
 }
 
-// discoverTask runs discovery table operations.
-// Only one discoverTask is active at any time.
-// discoverTask.Do performs a random lookup.
-type discoverTask struct {
-	want    int
-	results []*enode.Node
-}
-
-func (t *discoverTask) Do(srv *Server) {
-	t.results = enode.ReadNodes(srv.discmix, t.want)
-}
-
-func (t *discoverTask) String() string {
-	s := "discovery query"
-	if len(t.results) > 0 {
-		s += fmt.Sprintf(" (%d results)", len(t.results))
-	} else {
-		s += fmt.Sprintf(" (want %d)", t.want)
+func cleanupDialErr(err error) error {
+	if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" {
+		return netErr.Err
 	}
-	return s
-}
-
-// A waitExpireTask is generated if there are no other tasks
-// to keep the loop in Server.run ticking.
-type waitExpireTask struct {
-	time.Duration
-}
-
-func (t waitExpireTask) Do(*Server) {
-	time.Sleep(t.Duration)
-}
-func (t waitExpireTask) String() string {
-	return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
+	return err
 }

p2p/dial_test.go (+583, -501)

@@ -17,574 +17,656 @@
 package p2p
 
 import (
-	"encoding/binary"
+	"context"
+	"errors"
+	"fmt"
+	"math/rand"
 	"net"
 	"reflect"
-	"strings"
+	"sync"
 	"testing"
 	"time"
 
-	"github.com/davecgh/go-spew/spew"
+	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/internal/testlog"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p/enode"
-	"github.com/ethereum/go-ethereum/p2p/enr"
 	"github.com/ethereum/go-ethereum/p2p/netutil"
 )
 
-func init() {
-	spew.Config.Indent = "\t"
-}
-
-type dialtest struct {
-	init   *dialstate // state before and after the test.
-	rounds []round
-}
-
-type round struct {
-	peers []*Peer // current peer set
-	done  []task  // tasks that got done this round
-	new   []task  // the result must match this one
-}
+// This test checks that dynamic dials are launched from discovery results.
+func TestDialSchedDynDial(t *testing.T) {
+	t.Parallel()
 
-func runDialTest(t *testing.T, test dialtest) {
-	var (
-		vtime   time.Time
-		running int
-	)
-	pm := func(ps []*Peer) map[enode.ID]*Peer {
-		m := make(map[enode.ID]*Peer)
-		for _, p := range ps {
-			m[p.ID()] = p
-		}
-		return m
+	config := dialConfig{
+		maxActiveDials: 5,
+		maxDialPeers:   4,
 	}
-	for i, round := range test.rounds {
-		for _, task := range round.done {
-			running--
-			if running < 0 {
-				panic("running task counter underflow")
-			}
-			test.init.taskDone(task, vtime)
-		}
+	runDialTest(t, config, []dialTestRound{
+		// 3 out of 4 peers are connected, leaving 2 dial slots.
+		// 9 nodes are discovered, but only 2 are dialed.
+		{
+			peersAdded: []*conn{
+				{flags: staticDialedConn, node: newNode(uintID(0x00), "")},
+				{flags: dynDialedConn, node: newNode(uintID(0x01), "")},
+				{flags: dynDialedConn, node: newNode(uintID(0x02), "")},
+			},
+			discovered: []*enode.Node{
+				newNode(uintID(0x00), "127.0.0.1:30303"), // not dialed because already connected as static peer
+				newNode(uintID(0x02), "127.0.0.1:30303"), // ...
+				newNode(uintID(0x03), "127.0.0.1:30303"),
+				newNode(uintID(0x04), "127.0.0.1:30303"),
+				newNode(uintID(0x05), "127.0.0.1:30303"), // not dialed because there are only two slots
+				newNode(uintID(0x06), "127.0.0.1:30303"), // ...
+				newNode(uintID(0x07), "127.0.0.1:30303"), // ...
+				newNode(uintID(0x08), "127.0.0.1:30303"), // ...
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x03), "127.0.0.1:30303"),
+				newNode(uintID(0x04), "127.0.0.1:30303"),
+			},
+		},
 
-		new := test.init.newTasks(running, pm(round.peers), vtime)
-		if !sametasks(new, round.new) {
-			t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v",
-				i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
-		}
-		t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new)))
+		// One dial completes, freeing one dial slot.
+		{
+			failed: []enode.ID{
+				uintID(0x04),
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x05), "127.0.0.1:30303"),
+			},
+		},
 
-		// Time advances by 16 seconds on every round.
-		vtime = vtime.Add(16 * time.Second)
-		running += len(new)
-	}
-}
+		// Dial to 0x03 completes, filling the last remaining peer slot.
+		{
+			succeeded: []enode.ID{
+				uintID(0x03),
+			},
+			failed: []enode.ID{
+				uintID(0x05),
+			},
+			discovered: []*enode.Node{
+				newNode(uintID(0x09), "127.0.0.1:30303"), // not dialed because there are no free slots
+			},
+		},
 
-// This test checks that dynamic dials are launched from discovery results.
-func TestDialStateDynDial(t *testing.T) {
-	config := &Config{Logger: testlog.Logger(t, log.LvlTrace)}
-	runDialTest(t, dialtest{
-		init: newDialState(enode.ID{}, 5, config),
-		rounds: []round{
-			// A discovery query is launched.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-				},
-				new: []task{
-					&discoverTask{want: 3},
-				},
-			},
-			// Dynamic dials are launched when it completes.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-				},
-				done: []task{
-					&discoverTask{results: []*enode.Node{
-						newNode(uintID(2), nil), // this one is already connected and not dialed.
-						newNode(uintID(3), nil),
-						newNode(uintID(4), nil),
-						newNode(uintID(5), nil),
-						newNode(uintID(6), nil), // these are not tried because max dyn dials is 5
-						newNode(uintID(7), nil), // ...
-					}},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
-				},
-			},
-			// Some of the dials complete but no new ones are launched yet because
-			// the sum of active dial count and dynamic peer count is == maxDynDials.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
-				},
-			},
-			// No new dial tasks are launched in the this round because
-			// maxDynDials has been reached.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
-				},
-				new: []task{
-					&waitExpireTask{Duration: 19 * time.Second},
-				},
-			},
-			// In this round, the peer with id 2 drops off. The query
-			// results from last discovery lookup are reused.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
-				},
-			},
-			// More peers (3,4) drop off and dial for ID 6 completes.
-			// The last query result from the discovery lookup is reused
-			// and a new one is spawned because more candidates are needed.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)},
-					&discoverTask{want: 2},
-				},
-			},
-			// Peer 7 is connected, but there still aren't enough dynamic peers
-			// (4 out of 5). However, a discovery is already running, so ensure
-			// no new is started.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)},
-				},
-			},
-			// Finish the running node discovery with an empty set. A new lookup
-			// should be immediately requested.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}},
-				},
-				done: []task{
-					&discoverTask{},
-				},
-				new: []task{
-					&discoverTask{want: 2},
-				},
+		// 3 peers drop off, creating 6 dial slots. Check that 5 of those slots
+		// (i.e. up to maxActiveDialTasks) are used.
+		{
+			peersRemoved: []enode.ID{
+				uintID(0x00),
+				uintID(0x01),
+				uintID(0x02),
+			},
+			discovered: []*enode.Node{
+				newNode(uintID(0x0a), "127.0.0.1:30303"),
+				newNode(uintID(0x0b), "127.0.0.1:30303"),
+				newNode(uintID(0x0c), "127.0.0.1:30303"),
+				newNode(uintID(0x0d), "127.0.0.1:30303"),
+				newNode(uintID(0x0f), "127.0.0.1:30303"),
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x06), "127.0.0.1:30303"),
+				newNode(uintID(0x07), "127.0.0.1:30303"),
+				newNode(uintID(0x08), "127.0.0.1:30303"),
+				newNode(uintID(0x09), "127.0.0.1:30303"),
+				newNode(uintID(0x0a), "127.0.0.1:30303"),
 			},
 		},
 	})
 }
 
-// Tests that bootnodes are dialed if no peers are connectd, but not otherwise.
-func TestDialStateDynDialBootnode(t *testing.T) {
-	config := &Config{
-		BootstrapNodes: []*enode.Node{
-			newNode(uintID(1), nil),
-			newNode(uintID(2), nil),
-			newNode(uintID(3), nil),
-		},
-		Logger: testlog.Logger(t, log.LvlTrace),
+// This test checks that candidates that do not match the netrestrict list are not dialed.
+func TestDialSchedNetRestrict(t *testing.T) {
+	t.Parallel()
+
+	nodes := []*enode.Node{
+		newNode(uintID(0x01), "127.0.0.1:30303"),
+		newNode(uintID(0x02), "127.0.0.2:30303"),
+		newNode(uintID(0x03), "127.0.0.3:30303"),
+		newNode(uintID(0x04), "127.0.0.4:30303"),
+		newNode(uintID(0x05), "127.0.2.5:30303"),
+		newNode(uintID(0x06), "127.0.2.6:30303"),
+		newNode(uintID(0x07), "127.0.2.7:30303"),
+		newNode(uintID(0x08), "127.0.2.8:30303"),
+	}
+	config := dialConfig{
+		netRestrict:    new(netutil.Netlist),
+		maxActiveDials: 10,
+		maxDialPeers:   10,
 	}
-	runDialTest(t, dialtest{
-		init: newDialState(enode.ID{}, 5, config),
-		rounds: []round{
-			{
-				new: []task{
-					&discoverTask{want: 5},
-				},
-			},
-			{
-				done: []task{
-					&discoverTask{
-						results: []*enode.Node{
-							newNode(uintID(4), nil),
-							newNode(uintID(5), nil),
-						},
-					},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
-					&discoverTask{want: 3},
-				},
-			},
-			// No dials succeed, bootnodes still pending fallback interval
-			{},
-			// 1 bootnode attempted as fallback interval was reached
-			{
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)},
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
-				},
-			},
-			// No dials succeed, 2nd bootnode is attempted
-			{
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
-				},
-			},
-			// No dials succeed, 3rd bootnode is attempted
-			{
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
-				},
-			},
-			// No dials succeed, 1st bootnode is attempted again, expired random nodes retried
-			{
-				done: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)},
-					&discoverTask{results: []*enode.Node{
-						newNode(uintID(6), nil),
-					}},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)},
-					&discoverTask{want: 4},
-				},
-			},
-			// Random dial succeeds, no more bootnodes are attempted
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}},
-				},
+	config.netRestrict.Add("127.0.2.0/24")
+	runDialTest(t, config, []dialTestRound{
+		{
+			discovered:   nodes,
+			wantNewDials: nodes[4:8],
+		},
+		{
+			succeeded: []enode.ID{
+				nodes[4].ID(),
+				nodes[5].ID(),
+				nodes[6].ID(),
+				nodes[7].ID(),
 			},
 		},
 	})
 }
 
-func newNode(id enode.ID, ip net.IP) *enode.Node {
-	var r enr.Record
-	if ip != nil {
-		r.Set(enr.IP(ip))
+// This test checks that static dials work and obey the limits.
+func TestDialSchedStaticDial(t *testing.T) {
+	t.Parallel()
+
+	config := dialConfig{
+		maxActiveDials: 5,
+		maxDialPeers:   4,
 	}
-	return enode.SignNull(&r, id)
+	runDialTest(t, config, []dialTestRound{
+		// Static dials are launched for the nodes that
+		// aren't yet connected.
+		{
+			peersAdded: []*conn{
+				{flags: dynDialedConn, node: newNode(uintID(0x01), "127.0.0.1:30303")},
+				{flags: dynDialedConn, node: newNode(uintID(0x02), "127.0.0.2:30303")},
+			},
+			update: func(d *dialScheduler) {
+				// These two are not dialed because they're already connected
+				// as dynamic peers.
+				d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303"))
+				d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303"))
+				// These nodes will be dialed:
+				d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303"))
+				d.addStatic(newNode(uintID(0x04), "127.0.0.4:30303"))
+				d.addStatic(newNode(uintID(0x05), "127.0.0.5:30303"))
+				d.addStatic(newNode(uintID(0x06), "127.0.0.6:30303"))
+				d.addStatic(newNode(uintID(0x07), "127.0.0.7:30303"))
+				d.addStatic(newNode(uintID(0x08), "127.0.0.8:30303"))
+				d.addStatic(newNode(uintID(0x09), "127.0.0.9:30303"))
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x03), "127.0.0.3:30303"),
+				newNode(uintID(0x04), "127.0.0.4:30303"),
+				newNode(uintID(0x05), "127.0.0.5:30303"),
+				newNode(uintID(0x06), "127.0.0.6:30303"),
+			},
+		},
+		// Dial to 0x03 completes, filling a peer slot. One slot remains,
+		// two dials are launched to attempt to fill it.
+		{
+			succeeded: []enode.ID{
+				uintID(0x03),
+			},
+			failed: []enode.ID{
+				uintID(0x04),
+				uintID(0x05),
+				uintID(0x06),
+			},
+			wantResolves: map[enode.ID]*enode.Node{
+				uintID(0x04): nil,
+				uintID(0x05): nil,
+				uintID(0x06): nil,
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x08), "127.0.0.8:30303"),
+				newNode(uintID(0x09), "127.0.0.9:30303"),
+			},
+		},
+		// Peer 0x01 drops and 0x07 connects as inbound peer.
+		// Only 0x01 is dialed.
+		{
+			peersAdded: []*conn{
+				{flags: inboundConn, node: newNode(uintID(0x07), "127.0.0.7:30303")},
+			},
+			peersRemoved: []enode.ID{
+				uintID(0x01),
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x01), "127.0.0.1:30303"),
+			},
+		},
+	})
 }
 
-// // This test checks that candidates that do not match the netrestrict list are not dialed.
-func TestDialStateNetRestrict(t *testing.T) {
-	// This table always returns the same random nodes
-	// in the order given below.
-	nodes := []*enode.Node{
-		newNode(uintID(1), net.ParseIP("127.0.0.1")),
-		newNode(uintID(2), net.ParseIP("127.0.0.2")),
-		newNode(uintID(3), net.ParseIP("127.0.0.3")),
-		newNode(uintID(4), net.ParseIP("127.0.0.4")),
-		newNode(uintID(5), net.ParseIP("127.0.2.5")),
-		newNode(uintID(6), net.ParseIP("127.0.2.6")),
-		newNode(uintID(7), net.ParseIP("127.0.2.7")),
-		newNode(uintID(8), net.ParseIP("127.0.2.8")),
+// This test checks that removing static nodes stops connecting to them.
+func TestDialSchedRemoveStatic(t *testing.T) {
+	t.Parallel()
+
+	config := dialConfig{
+		maxActiveDials: 1,
+		maxDialPeers:   1,
 	}
-	restrict := new(netutil.Netlist)
-	restrict.Add("127.0.2.0/24")
-
-	runDialTest(t, dialtest{
-		init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}),
-		rounds: []round{
-			{
-				new: []task{
-					&discoverTask{want: 10},
-				},
-			},
-			{
-				done: []task{
-					&discoverTask{results: nodes},
-				},
-				new: []task{
-					&dialTask{flags: dynDialedConn, dest: nodes[4]},
-					&dialTask{flags: dynDialedConn, dest: nodes[5]},
-					&dialTask{flags: dynDialedConn, dest: nodes[6]},
-					&dialTask{flags: dynDialedConn, dest: nodes[7]},
-					&discoverTask{want: 6},
-				},
+	runDialTest(t, config, []dialTestRound{
+		// Add static nodes.
+		{
+			update: func(d *dialScheduler) {
+				d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303"))
+				d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303"))
+				d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303"))
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x01), "127.0.0.1:30303"),
+			},
+		},
+		// Dial to 0x01 fails.
+		{
+			failed: []enode.ID{
+				uintID(0x01),
+			},
+			wantResolves: map[enode.ID]*enode.Node{
+				uintID(0x01): nil,
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x02), "127.0.0.2:30303"),
+			},
+		},
+		// All static nodes are removed. 0x01 is in history, 0x02 is being
+		// dialed, 0x03 is in staticPool.
+		{
+			update: func(d *dialScheduler) {
+				d.removeStatic(newNode(uintID(0x01), "127.0.0.1:30303"))
+				d.removeStatic(newNode(uintID(0x02), "127.0.0.2:30303"))
+				d.removeStatic(newNode(uintID(0x03), "127.0.0.3:30303"))
+			},
+			failed: []enode.ID{
+				uintID(0x02),
+			},
+			wantResolves: map[enode.ID]*enode.Node{
+				uintID(0x02): nil,
 			},
 		},
+		// Since all static nodes are removed, they should not be dialed again.
+		{}, {}, {},
 	})
 }
 
-// This test checks that static dials are launched.
-func TestDialStateStaticDial(t *testing.T) {
-	config := &Config{
-		StaticNodes: []*enode.Node{
-			newNode(uintID(1), nil),
-			newNode(uintID(2), nil),
-			newNode(uintID(3), nil),
-			newNode(uintID(4), nil),
-			newNode(uintID(5), nil),
+// This test checks that static dials are selected at random.
+func TestDialSchedManyStaticNodes(t *testing.T) {
+	t.Parallel()
+
+	config := dialConfig{maxDialPeers: 2}
+	runDialTest(t, config, []dialTestRound{
+		{
+			peersAdded: []*conn{
+				{flags: dynDialedConn, node: newNode(uintID(0xFFFE), "")},
+				{flags: dynDialedConn, node: newNode(uintID(0xFFFF), "")},
+			},
+			update: func(d *dialScheduler) {
+				for id := uint16(0); id < 2000; id++ {
+					n := newNode(uintID(id), "127.0.0.1:30303")
+					d.addStatic(n)
+				}
+			},
 		},
-		Logger: testlog.Logger(t, log.LvlTrace),
-	}
-	runDialTest(t, dialtest{
-		init: newDialState(enode.ID{}, 0, config),
-		rounds: []round{
-			// Static dials are launched for the nodes that
-			// aren't yet connected.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-				},
-				new: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
-				},
-			},
-			// No new tasks are launched in this round because all static
-			// nodes are either connected or still being dialed.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
-				},
-			},
-			// No new dial tasks are launched because all static
-			// nodes are now connected.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)},
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)},
-				},
-				new: []task{
-					&waitExpireTask{Duration: 19 * time.Second},
-				},
-			},
-			// Wait a round for dial history to expire, no new tasks should spawn.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}},
-				},
-			},
-			// If a static node is dropped, it should be immediately redialed,
-			// irrespective whether it was originally static or dynamic.
-			{
-				done: []task{
-					&waitExpireTask{Duration: 19 * time.Second},
-				},
-				peers: []*Peer{
-					{rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}},
-				},
-				new: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
-				},
+		{
+			peersRemoved: []enode.ID{
+				uintID(0xFFFE),
+				uintID(0xFFFF),
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x0085), "127.0.0.1:30303"),
+				newNode(uintID(0x02dc), "127.0.0.1:30303"),
+				newNode(uintID(0x0285), "127.0.0.1:30303"),
+				newNode(uintID(0x00cb), "127.0.0.1:30303"),
 			},
 		},
 	})
 }
 
 // This test checks that past dials are not retried for some time.
-func TestDialStateCache(t *testing.T) {
-	config := &Config{
-		StaticNodes: []*enode.Node{
-			newNode(uintID(1), nil),
-			newNode(uintID(2), nil),
-			newNode(uintID(3), nil),
+func TestDialSchedHistory(t *testing.T) {
+	t.Parallel()
+
+	config := dialConfig{
+		maxActiveDials: 3,
+		maxDialPeers:   3,
+	}
+	runDialTest(t, config, []dialTestRound{
+		{
+			update: func(d *dialScheduler) {
+				d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303"))
+				d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303"))
+				d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303"))
+			},
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x01), "127.0.0.1:30303"),
+				newNode(uintID(0x02), "127.0.0.2:30303"),
+				newNode(uintID(0x03), "127.0.0.3:30303"),
+			},
 		},
-		Logger: testlog.Logger(t, log.LvlTrace),
+		// No new tasks are launched in this round because all static
+		// nodes are either connected or still being dialed.
+		{
+			succeeded: []enode.ID{
+				uintID(0x01),
+				uintID(0x02),
+			},
+			failed: []enode.ID{
+				uintID(0x03),
+			},
+			wantResolves: map[enode.ID]*enode.Node{
+				uintID(0x03): nil,
+			},
+		},
+		// Nothing happens in this round because we're waiting for
+		// node 0x3's history entry to expire.
+		{},
+		// The cache entry for node 0x03 has expired and is retried.
+		{
+			wantNewDials: []*enode.Node{
+				newNode(uintID(0x03), "127.0.0.3:30303"),
+			},
+		},
+	})
+}
+
+func TestDialSchedResolve(t *testing.T) {
+	t.Parallel()
+
+	config := dialConfig{
+		maxActiveDials: 1,
+		maxDialPeers:   1,
 	}
-	runDialTest(t, dialtest{
-		init: newDialState(enode.ID{}, 0, config),
-		rounds: []round{
-			// Static dials are launched for the nodes that
-			// aren't yet connected.
-			{
-				peers: nil,
-				new: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
-				},
-			},
-			// No new tasks are launched in this round because all static
-			// nodes are either connected or still being dialed.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)},
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)},
-				},
-			},
-			// A salvage task is launched to wait for node 3's history
-			// entry to expire.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
-				},
-				done: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
-				},
-				new: []task{
-					&waitExpireTask{Duration: 19 * time.Second},
-				},
-			},
-			// Still waiting for node 3's entry to expire in the cache.
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
-				},
-			},
-			{
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
-				},
-			},
-			// The cache entry for node 3 has expired and is retried.
-			{
-				done: []task{
-					&waitExpireTask{Duration: 19 * time.Second},
-				},
-				peers: []*Peer{
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}},
-					{rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}},
-				},
-				new: []task{
-					&dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)},
-				},
+	node := newNode(uintID(0x01), "")
+	resolved := newNode(uintID(0x01), "127.0.0.1:30303")
+	resolved2 := newNode(uintID(0x01), "127.0.0.55:30303")
+	runDialTest(t, config, []dialTestRound{
+		{
+			update: func(d *dialScheduler) {
+				d.addStatic(node)
+			},
+			wantResolves: map[enode.ID]*enode.Node{
+				uintID(0x01): resolved,
+			},
+			wantNewDials: []*enode.Node{
+				resolved,
+			},
+		},
+		{
+			failed: []enode.ID{
+				uintID(0x01),
+			},
+			wantResolves: map[enode.ID]*enode.Node{
+				uintID(0x01): resolved2,
+			},
+			wantNewDials: []*enode.Node{
+				resolved2,
 			},
 		},
 	})
 }
 
-func TestDialResolve(t *testing.T) {
-	config := &Config{
-		Logger: testlog.Logger(t, log.LvlTrace),
-		Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}},
+// -------
+// Code below here is the framework for the tests above.
+
+type dialTestRound struct {
+	peersAdded   []*conn
+	peersRemoved []enode.ID
+	update       func(*dialScheduler) // called at beginning of round
+	discovered   []*enode.Node        // newly discovered nodes
+	succeeded    []enode.ID           // dials which succeed this round
+	failed       []enode.ID           // dials which fail this round
+	wantResolves map[enode.ID]*enode.Node
+	wantNewDials []*enode.Node // dials that should be launched in this round
+}
+
+func runDialTest(t *testing.T, config dialConfig, rounds []dialTestRound) {
+	var (
+		clock    = new(mclock.Simulated)
+		iterator = newDialTestIterator()
+		dialer   = newDialTestDialer()
+		resolver = new(dialTestResolver)
+		peers    = make(map[enode.ID]*conn)
+		setupCh  = make(chan *conn)
+	)
+
+	// Override config.
+	config.clock = clock
+	config.dialer = dialer
+	config.resolver = resolver
+	config.log = testlog.Logger(t, log.LvlTrace)
+	config.rand = rand.New(rand.NewSource(0x1111))
+
+	// Set up the dialer. The setup function below runs on the dialTask
+	// goroutine and adds the peer.
+	var dialsched *dialScheduler
+	setup := func(fd net.Conn, f connFlag, node *enode.Node) error {
+		conn := &conn{flags: f, node: node}
+		dialsched.peerAdded(conn)
+		setupCh <- conn
+		return nil
 	}
-	resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
-	resolver := &resolveMock{answer: resolved}
-	state := newDialState(enode.ID{}, 0, config)
-
-	// Check that the task is generated with an incomplete ID.
-	dest := newNode(uintID(1), nil)
-	state.addStatic(dest)
-	tasks := state.newTasks(0, nil, time.Time{})
-	if !reflect.DeepEqual(tasks, []task{&dialTask{flags: staticDialedConn, dest: dest}}) {
-		t.Fatalf("expected dial task, got %#v", tasks)
+	dialsched = newDialScheduler(config, iterator, setup)
+	defer dialsched.stop()
+
+	for i, round := range rounds {
+		// Apply peer set updates.
+		for _, c := range round.peersAdded {
+			if peers[c.node.ID()] != nil {
+				t.Fatalf("round %d: peer %v already connected", i, c.node.ID())
+			}
+			dialsched.peerAdded(c)
+			peers[c.node.ID()] = c
+		}
+		for _, id := range round.peersRemoved {
+			c := peers[id]
+			if c == nil {
+				t.Fatalf("round %d: can't remove non-existent peer %v", i, id)
+			}
+			dialsched.peerRemoved(c)
+		}
+
+		// Init round.
+		t.Logf("round %d (%d peers)", i, len(peers))
+		resolver.setAnswers(round.wantResolves)
+		if round.update != nil {
+			round.update(dialsched)
+		}
+		iterator.addNodes(round.discovered)
+
+		// Unblock dialTask goroutines.
+		if err := dialer.completeDials(round.succeeded, nil); err != nil {
+			t.Fatalf("round %d: %v", i, err)
+		}
+		for range round.succeeded {
+			conn := <-setupCh
+			peers[conn.node.ID()] = conn
+		}
+		if err := dialer.completeDials(round.failed, errors.New("oops")); err != nil {
+			t.Fatalf("round %d: %v", i, err)
+		}
+
+		// Wait for new tasks.
+		if err := dialer.waitForDials(round.wantNewDials); err != nil {
+			t.Fatalf("round %d: %v", i, err)
+		}
+		if !resolver.checkCalls() {
+			t.Fatalf("unexpected calls to Resolve: %v", resolver.calls)
+		}
+
+		clock.Run(16 * time.Second)
 	}
+}
 
-	// Now run the task, it should resolve the ID once.
-	srv := &Server{
-		Config:             *config,
-		log:                config.Logger,
-		staticNodeResolver: resolver,
+// dialTestIterator is the input iterator for dialer tests. This works a bit like a channel
+// with infinite buffer: nodes are added to the buffer with addNodes, which unblocks Next
+// and returns them from the iterator.
+type dialTestIterator struct {
+	cur *enode.Node
+
+	mu     sync.Mutex
+	buf    []*enode.Node
+	cond   *sync.Cond
+	closed bool
+}
+
+func newDialTestIterator() *dialTestIterator {
+	it := &dialTestIterator{}
+	it.cond = sync.NewCond(&it.mu)
+	return it
+}
+
+// addNodes adds nodes to the iterator buffer and unblocks Next.
+func (it *dialTestIterator) addNodes(nodes []*enode.Node) {
+	it.mu.Lock()
+	defer it.mu.Unlock()
+
+	it.buf = append(it.buf, nodes...)
+	it.cond.Signal()
+}
+
+// Node returns the current node.
+func (it *dialTestIterator) Node() *enode.Node {
+	return it.cur
+}
+
+// Next moves to the next node.
+func (it *dialTestIterator) Next() bool {
+	it.mu.Lock()
+	defer it.mu.Unlock()
+
+	it.cur = nil
+	for len(it.buf) == 0 && !it.closed {
+		it.cond.Wait()
 	}
-	tasks[0].Do(srv)
-	if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) {
-		t.Fatalf("wrong resolve calls, got %v", resolver.calls)
+	if it.closed {
+		return false
 	}
+	it.cur = it.buf[0]
+	copy(it.buf[:], it.buf[1:])
+	it.buf = it.buf[:len(it.buf)-1]
+	return true
+}
+
+// Close ends the iterator, unblocking Next.
+func (it *dialTestIterator) Close() {
+	it.mu.Lock()
+	defer it.mu.Unlock()
+
+	it.closed = true
+	it.buf = nil
+	it.cond.Signal()
+}
+
+// dialTestDialer is the NodeDialer used by runDialTest.
+type dialTestDialer struct {
+	init    chan *dialTestReq
+	blocked map[enode.ID]*dialTestReq
+}
 
-	// Report it as done to the dialer, which should update the static node record.
-	state.taskDone(tasks[0], time.Now())
-	if state.static[uintID(1)].dest != resolved {
-		t.Fatalf("state.dest not updated")
+type dialTestReq struct {
+	n       *enode.Node
+	unblock chan error
+}
+
+func newDialTestDialer() *dialTestDialer {
+	return &dialTestDialer{
+		init:    make(chan *dialTestReq),
+		blocked: make(map[enode.ID]*dialTestReq),
 	}
 }
 
-// compares task lists but doesn't care about the order.
-func sametasks(a, b []task) bool {
-	if len(a) != len(b) {
-		return false
+// Dial implements NodeDialer.
+func (d *dialTestDialer) Dial(ctx context.Context, n *enode.Node) (net.Conn, error) {
+	req := &dialTestReq{n: n, unblock: make(chan error, 1)}
+	select {
+	case d.init <- req:
+		select {
+		case err := <-req.unblock:
+			pipe, _ := net.Pipe()
+			return pipe, err
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	case <-ctx.Done():
+		return nil, ctx.Err()
 	}
-next:
-	for _, ta := range a {
-		for _, tb := range b {
-			if reflect.DeepEqual(ta, tb) {
-				continue next
+}
+
+// waitForDials waits for calls to Dial with the given nodes as arguments.
+// The calls block until completeDials is called with the same nodes.
+func (d *dialTestDialer) waitForDials(nodes []*enode.Node) error {
+	waitset := make(map[enode.ID]*enode.Node)
+	for _, n := range nodes {
+		waitset[n.ID()] = n
+	}
+	timeout := time.NewTimer(1 * time.Second)
+	defer timeout.Stop()
+
+	for len(waitset) > 0 {
+		select {
+		case req := <-d.init:
+			want, ok := waitset[req.n.ID()]
+			if !ok {
+				return fmt.Errorf("attempt to dial unexpected node %v", req.n.ID())
+			}
+			if !reflect.DeepEqual(req.n, want) {
+				return fmt.Errorf("ENR of dialed node %v does not match test", req.n.ID())
 			}
+			delete(waitset, req.n.ID())
+			d.blocked[req.n.ID()] = req
+		case <-timeout.C:
+			var waitlist []enode.ID
+			for id := range waitset {
+				waitlist = append(waitlist, id)
+			}
+			return fmt.Errorf("timed out waiting for dials to %v", waitlist)
 		}
-		return false
 	}
-	return true
+
+	return d.checkUnexpectedDial()
+}
+
+func (d *dialTestDialer) checkUnexpectedDial() error {
+	select {
+	case req := <-d.init:
+		return fmt.Errorf("attempt to dial unexpected node %v", req.n.ID())
+	case <-time.After(150 * time.Millisecond):
+		return nil
+	}
 }
 
-func uintID(i uint32) enode.ID {
-	var id enode.ID
-	binary.BigEndian.PutUint32(id[:], i)
-	return id
+// completeDials unblocks calls to Dial for the given nodes.
+func (d *dialTestDialer) completeDials(ids []enode.ID, err error) error {
+	for _, id := range ids {
+		req := d.blocked[id]
+		if req == nil {
+			return fmt.Errorf("can't complete dial to %v", id)
+		}
+		req.unblock <- err
+	}
+	return nil
 }
 
-// for TestDialResolve
-type resolveMock struct {
-	calls  []*enode.Node
-	answer *enode.Node
+// dialTestResolver tracks calls to Resolve.
+type dialTestResolver struct {
+	mu      sync.Mutex
+	calls   []enode.ID
+	answers map[enode.ID]*enode.Node
 }
 
-func (t *resolveMock) Resolve(n *enode.Node) *enode.Node {
-	t.calls = append(t.calls, n)
-	return t.answer
+func (t *dialTestResolver) setAnswers(m map[enode.ID]*enode.Node) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	t.answers = m
+	t.calls = nil
+}
+
+func (t *dialTestResolver) checkCalls() bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	for _, id := range t.calls {
+		if _, ok := t.answers[id]; !ok {
+			return false
+		}
+	}
+	return true
+}
+
+func (t *dialTestResolver) Resolve(n *enode.Node) *enode.Node {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	t.calls = append(t.calls, n.ID())
+	return t.answers[n.ID()]
 }
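
For orientation, a minimal sketch of how a single round might be expressed against
runDialTest above. The dialTestRound field names (discovered, succeeded, wantNewDials)
are taken from their use in the test loop; the node IDs and addresses are purely
illustrative, and the literal would live inside a test function:

	round := dialTestRound{
		// Nodes fed to the discovery iterator in this round.
		discovered: []*enode.Node{newNode(uintID(2), "127.0.0.2:30303")},
		// A dial launched in an earlier round that should now complete.
		succeeded: []enode.ID{uintID(1)},
		// Dials the scheduler is expected to launch in this round.
		wantNewDials: []*enode.Node{newNode(uintID(2), "127.0.0.2:30303")},
	}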

+ 42 - 2
p2p/peer_test.go

@@ -17,15 +17,20 @@
 package p2p
 
 import (
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"math/rand"
 	"net"
 	"reflect"
+	"strconv"
+	"strings"
 	"testing"
 	"time"
 
 	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
 )
 
 var discard = Protocol{
@@ -45,10 +50,45 @@ var discard = Protocol{
 	},
 }
 
+// uintID encodes i into a node ID.
+func uintID(i uint16) enode.ID {
+	var id enode.ID
+	binary.BigEndian.PutUint16(id[:], i)
+	return id
+}
+
+// newNode creates a node record with the given address.
+func newNode(id enode.ID, addr string) *enode.Node {
+	var r enr.Record
+	if addr != "" {
+		// Set the port if present.
+		if strings.Contains(addr, ":") {
+			hs, ps, err := net.SplitHostPort(addr)
+			if err != nil {
+				panic(fmt.Errorf("invalid address %q", addr))
+			}
+			port, err := strconv.Atoi(ps)
+			if err != nil {
+				panic(fmt.Errorf("invalid port in %q", addr))
+			}
+			r.Set(enr.TCP(port))
+			r.Set(enr.UDP(port))
+			addr = hs
+		}
+		// Set the IP.
+		ip := net.ParseIP(addr)
+		if ip == nil {
+			panic(fmt.Errorf("invalid IP %q", addr))
+		}
+		r.Set(enr.IP(ip))
+	}
+	return enode.SignNull(&r, id)
+}
+
 func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
 	fd1, fd2 := net.Pipe()
-	c1 := &conn{fd: fd1, node: newNode(randomID(), nil), transport: newTestTransport(&newkey().PublicKey, fd1)}
-	c2 := &conn{fd: fd2, node: newNode(randomID(), nil), transport: newTestTransport(&newkey().PublicKey, fd2)}
+	c1 := &conn{fd: fd1, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd1)}
+	c2 := &conn{fd: fd2, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd2)}
 	for _, p := range protos {
 		c1.caps = append(c1.caps, p.cap())
 		c2.caps = append(c2.caps, p.cap())

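The two helpers above replace the old newNode(id, ip) test constructor with a string
address. A short sketch of the forms it accepts (addresses are illustrative):

	var (
		n1 = newNode(uintID(1), "127.0.0.1:30303") // IP plus TCP and UDP port 30303
		n2 = newNode(uintID(2), "10.0.0.5")        // IP only, no ports
		n3 = newNode(uintID(3), "")                // empty record, no endpoint
	)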
+ 144 - 176
p2p/server.go

@@ -51,7 +51,6 @@ const (
 	discmixTimeout = 5 * time.Second
 
 	// Connectivity defaults.
-	maxActiveDialTasks     = 16
 	defaultMaxPendingPeers = 50
 	defaultDialRatio       = 3
 
@@ -156,6 +155,8 @@ type Config struct {
 
 	// Logger is a custom logger to use with the p2p.Server.
 	Logger log.Logger `toml:",omitempty"`
+
+	clock mclock.Clock
 }
 
 // Server manages all peer connections.
@@ -183,13 +184,10 @@ type Server struct {
 	ntab      *discover.UDPv4
 	DiscV5    *discv5.Network
 	discmix   *enode.FairMix
-
-	staticNodeResolver nodeResolver
+	dialsched *dialScheduler
 
 	// Channels into the run loop.
 	quit                    chan struct{}
-	addstatic               chan *enode.Node
-	removestatic            chan *enode.Node
 	addtrusted              chan *enode.Node
 	removetrusted           chan *enode.Node
 	peerOp                  chan peerOpFunc
@@ -302,47 +300,57 @@ func (srv *Server) LocalNode() *enode.LocalNode {
 // Peers returns all connected peers.
 func (srv *Server) Peers() []*Peer {
 	var ps []*Peer
-	select {
-	// Note: We'd love to put this function into a variable but
-	// that seems to cause a weird compiler error in some
-	// environments.
-	case srv.peerOp <- func(peers map[enode.ID]*Peer) {
+	srv.doPeerOp(func(peers map[enode.ID]*Peer) {
 		for _, p := range peers {
 			ps = append(ps, p)
 		}
-	}:
-		<-srv.peerOpDone
-	case <-srv.quit:
-	}
+	})
 	return ps
 }
 
 // PeerCount returns the number of connected peers.
 func (srv *Server) PeerCount() int {
 	var count int
-	select {
-	case srv.peerOp <- func(ps map[enode.ID]*Peer) { count = len(ps) }:
-		<-srv.peerOpDone
-	case <-srv.quit:
-	}
+	srv.doPeerOp(func(ps map[enode.ID]*Peer) {
+		count = len(ps)
+	})
 	return count
 }
 
-// AddPeer connects to the given node and maintains the connection until the
-// server is shut down. If the connection fails for any reason, the server will
-// attempt to reconnect the peer.
+// AddPeer adds the given node to the static node set. When there is room in the peer set,
+// the server will connect to the node. If the connection fails for any reason, the server
+// will attempt to reconnect the peer.
 func (srv *Server) AddPeer(node *enode.Node) {
-	select {
-	case srv.addstatic <- node:
-	case <-srv.quit:
-	}
+	srv.dialsched.addStatic(node)
 }
 
-// RemovePeer disconnects from the given node
+// RemovePeer removes a node from the static node set. It also disconnects from the given
+// node if it is currently connected as a peer.
+//
+// This method blocks until all protocols have exited and the peer is removed. Do not use
+// RemovePeer in protocol implementations; call Disconnect on the Peer instead.
 func (srv *Server) RemovePeer(node *enode.Node) {
-	select {
-	case srv.removestatic <- node:
-	case <-srv.quit:
+	var (
+		ch  chan *PeerEvent
+		sub event.Subscription
+	)
+	// Disconnect the peer on the main loop.
+	srv.doPeerOp(func(peers map[enode.ID]*Peer) {
+		srv.dialsched.removeStatic(node)
+		if peer := peers[node.ID()]; peer != nil {
+			ch = make(chan *PeerEvent, 1)
+			sub = srv.peerFeed.Subscribe(ch)
+			peer.Disconnect(DiscRequested)
+		}
+	})
+	// Wait for the peer connection to end.
+	if ch != nil {
+		defer sub.Unsubscribe()
+		for ev := range ch {
+			if ev.Peer == node.ID() && ev.Type == PeerEventTypeDrop {
+				return
+			}
+		}
 	}
 }
 
@@ -437,6 +445,9 @@ func (srv *Server) Start() (err error) {
 	if srv.log == nil {
 		srv.log = log.Root()
 	}
+	if srv.clock == nil {
+		srv.clock = mclock.System{}
+	}
 	if srv.NoDial && srv.ListenAddr == "" {
 		srv.log.Warn("P2P server will be useless, neither dialing nor listening")
 	}
@@ -451,15 +462,10 @@ func (srv *Server) Start() (err error) {
 	if srv.listenFunc == nil {
 		srv.listenFunc = net.Listen
 	}
-	if srv.Dialer == nil {
-		srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
-	}
 	srv.quit = make(chan struct{})
 	srv.delpeer = make(chan peerDrop)
 	srv.checkpointPostHandshake = make(chan *conn)
 	srv.checkpointAddPeer = make(chan *conn)
-	srv.addstatic = make(chan *enode.Node)
-	srv.removestatic = make(chan *enode.Node)
 	srv.addtrusted = make(chan *enode.Node)
 	srv.removetrusted = make(chan *enode.Node)
 	srv.peerOp = make(chan peerOpFunc)
@@ -476,11 +482,10 @@ func (srv *Server) Start() (err error) {
 	if err := srv.setupDiscovery(); err != nil {
 		return err
 	}
+	srv.setupDialScheduler()
 
-	dynPeers := srv.maxDialedConns()
-	dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config)
 	srv.loopWG.Add(1)
-	go srv.run(dialer)
+	go srv.run()
 	return nil
 }
 
@@ -583,7 +588,6 @@ func (srv *Server) setupDiscovery() error {
 		}
 		srv.ntab = ntab
 		srv.discmix.AddSource(ntab.RandomNodes())
-		srv.staticNodeResolver = ntab
 	}
 
 	// Discovery V5
@@ -606,6 +610,47 @@ func (srv *Server) setupDiscovery() error {
 	return nil
 }
 
+func (srv *Server) setupDialScheduler() {
+	config := dialConfig{
+		self:           srv.localnode.ID(),
+		maxDialPeers:   srv.maxDialedConns(),
+		maxActiveDials: srv.MaxPendingPeers,
+		log:            srv.Logger,
+		netRestrict:    srv.NetRestrict,
+		dialer:         srv.Dialer,
+		clock:          srv.clock,
+	}
+	if srv.ntab != nil {
+		config.resolver = srv.ntab
+	}
+	if config.dialer == nil {
+		config.dialer = tcpDialer{&net.Dialer{Timeout: defaultDialTimeout}}
+	}
+	srv.dialsched = newDialScheduler(config, srv.discmix, srv.SetupConn)
+	for _, n := range srv.StaticNodes {
+		srv.dialsched.addStatic(n)
+	}
+}
+
+func (srv *Server) maxInboundConns() int {
+	return srv.MaxPeers - srv.maxDialedConns()
+}
+
+func (srv *Server) maxDialedConns() (limit int) {
+	if srv.NoDial || srv.MaxPeers == 0 {
+		return 0
+	}
+	if srv.DialRatio == 0 {
+		limit = srv.MaxPeers / defaultDialRatio
+	} else {
+		limit = srv.MaxPeers / srv.DialRatio
+	}
+	if limit == 0 {
+		limit = 1
+	}
+	return limit
+}
+
 func (srv *Server) setupListening() error {
 	// Launch the listener.
 	listener, err := srv.listenFunc("tcp", srv.ListenAddr)
@@ -632,112 +677,55 @@ func (srv *Server) setupListening() error {
 	return nil
 }
 
-type dialer interface {
-	newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task
-	taskDone(task, time.Time)
-	addStatic(*enode.Node)
-	removeStatic(*enode.Node)
+// doPeerOp runs fn on the main loop.
+func (srv *Server) doPeerOp(fn peerOpFunc) {
+	select {
+	case srv.peerOp <- fn:
+		<-srv.peerOpDone
+	case <-srv.quit:
+	}
 }
 
-func (srv *Server) run(dialstate dialer) {
+// run is the main loop of the server.
+func (srv *Server) run() {
 	srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4())
 	defer srv.loopWG.Done()
 	defer srv.nodedb.Close()
 	defer srv.discmix.Close()
+	defer srv.dialsched.stop()
 
 	var (
 		peers        = make(map[enode.ID]*Peer)
 		inboundCount = 0
 		trusted      = make(map[enode.ID]bool, len(srv.TrustedNodes))
-		taskdone     = make(chan task, maxActiveDialTasks)
-		tick         = time.NewTicker(30 * time.Second)
-		runningTasks []task
-		queuedTasks  []task // tasks that can't run yet
 	)
-	defer tick.Stop()
-
 	// Put trusted nodes into a map to speed up checks.
 	// Trusted peers are loaded on startup or added via AddTrustedPeer RPC.
 	for _, n := range srv.TrustedNodes {
 		trusted[n.ID()] = true
 	}
 
-	// removes t from runningTasks
-	delTask := func(t task) {
-		for i := range runningTasks {
-			if runningTasks[i] == t {
-				runningTasks = append(runningTasks[:i], runningTasks[i+1:]...)
-				break
-			}
-		}
-	}
-	// starts until max number of active tasks is satisfied
-	startTasks := func(ts []task) (rest []task) {
-		i := 0
-		for ; len(runningTasks) < maxActiveDialTasks && i < len(ts); i++ {
-			t := ts[i]
-			srv.log.Trace("New dial task", "task", t)
-			go func() { t.Do(srv); taskdone <- t }()
-			runningTasks = append(runningTasks, t)
-		}
-		return ts[i:]
-	}
-	scheduleTasks := func() {
-		// Start from queue first.
-		queuedTasks = append(queuedTasks[:0], startTasks(queuedTasks)...)
-		// Query dialer for new tasks and start as many as possible now.
-		if len(runningTasks) < maxActiveDialTasks {
-			nt := dialstate.newTasks(len(runningTasks)+len(queuedTasks), peers, time.Now())
-			queuedTasks = append(queuedTasks, startTasks(nt)...)
-		}
-	}
-
 running:
 	for {
-		scheduleTasks()
-
 		select {
-		case <-tick.C:
-			// This is just here to ensure the dial scheduler runs occasionally.
-
 		case <-srv.quit:
 			// The server was stopped. Run the cleanup logic.
 			break running
 
-		case n := <-srv.addstatic:
-			// This channel is used by AddPeer to add to the
-			// ephemeral static peer list. Add it to the dialer,
-			// it will keep the node connected.
-			srv.log.Trace("Adding static node", "node", n)
-			dialstate.addStatic(n)
-
-		case n := <-srv.removestatic:
-			// This channel is used by RemovePeer to send a
-			// disconnect request to a peer and begin the
-			// stop keeping the node connected.
-			srv.log.Trace("Removing static node", "node", n)
-			dialstate.removeStatic(n)
-			if p, ok := peers[n.ID()]; ok {
-				p.Disconnect(DiscRequested)
-			}
-
 		case n := <-srv.addtrusted:
-			// This channel is used by AddTrustedPeer to add an enode
+			// This channel is used by AddTrustedPeer to add a node
 			// to the trusted node set.
 			srv.log.Trace("Adding trusted node", "node", n)
 			trusted[n.ID()] = true
-			// Mark any already-connected peer as trusted
 			if p, ok := peers[n.ID()]; ok {
 				p.rw.set(trustedConn, true)
 			}
 
 		case n := <-srv.removetrusted:
-			// This channel is used by RemoveTrustedPeer to remove an enode
+			// This channel is used by RemoveTrustedPeer to remove a node
 			// from the trusted node set.
 			srv.log.Trace("Removing trusted node", "node", n)
 			delete(trusted, n.ID())
-
-			// Unmark any already-connected peer as trusted
 			if p, ok := peers[n.ID()]; ok {
 				p.rw.set(trustedConn, false)
 			}
@@ -747,14 +735,6 @@ running:
 			op(peers)
 			srv.peerOpDone <- struct{}{}
 
-		case t := <-taskdone:
-			// A task got done. Tell dialstate about it so it
-			// can update its state and remove it from the active
-			// tasks list.
-			srv.log.Trace("Dial task done", "task", t)
-			dialstate.taskDone(t, time.Now())
-			delTask(t)
-
 		case c := <-srv.checkpointPostHandshake:
 			// A connection has passed the encryption handshake so
 			// the remote identity is known (but hasn't been verified yet).
@@ -771,33 +751,25 @@ running:
 			err := srv.addPeerChecks(peers, inboundCount, c)
 			if err == nil {
 				// The handshakes are done and it passed all checks.
-				p := newPeer(srv.log, c, srv.Protocols)
-				// If message events are enabled, pass the peerFeed
-				// to the peer
-				if srv.EnableMsgEvents {
-					p.events = &srv.peerFeed
-				}
-				name := truncateName(c.name)
-				p.log.Debug("Adding p2p peer", "addr", p.RemoteAddr(), "peers", len(peers)+1, "name", name)
-				go srv.runPeer(p)
+				p := srv.launchPeer(c)
 				peers[c.node.ID()] = p
-				if p.Inbound() {
-					inboundCount++
-				}
+				srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", truncateName(c.name))
+				srv.dialsched.peerAdded(c)
 				if conn, ok := c.fd.(*meteredConn); ok {
 					conn.handshakeDone(p)
 				}
+				if p.Inbound() {
+					inboundCount++
+				}
 			}
-			// The dialer logic relies on the assumption that
-			// dial tasks complete after the peer has been added or
-			// discarded. Unblock the task last.
 			c.cont <- err
 
 		case pd := <-srv.delpeer:
 			// A peer disconnected.
 			d := common.PrettyDuration(mclock.Now() - pd.created)
-			pd.log.Debug("Removing p2p peer", "addr", pd.RemoteAddr(), "peers", len(peers)-1, "duration", d, "req", pd.requested, "err", pd.err)
 			delete(peers, pd.ID())
+			srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", pd.ID(), "duration", d, "req", pd.requested, "err", pd.err)
+			srv.dialsched.peerRemoved(pd.rw)
 			if pd.Inbound() {
 				inboundCount--
 			}
@@ -822,14 +794,14 @@ running:
 	// is closed.
 	for len(peers) > 0 {
 		p := <-srv.delpeer
-		p.log.Trace("<-delpeer (spindown)", "remainingTasks", len(runningTasks))
+		p.log.Trace("<-delpeer (spindown)")
 		delete(peers, p.ID())
 	}
 }
 
 func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
 	switch {
-	case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
+	case !c.is(trustedConn) && len(peers) >= srv.MaxPeers:
 		return DiscTooManyPeers
 	case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
 		return DiscTooManyPeers
@@ -852,21 +824,6 @@ func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *
 	return srv.postHandshakeChecks(peers, inboundCount, c)
 }
 
-func (srv *Server) maxInboundConns() int {
-	return srv.MaxPeers - srv.maxDialedConns()
-}
-
-func (srv *Server) maxDialedConns() int {
-	if srv.NoDiscovery || srv.NoDial {
-		return 0
-	}
-	r := srv.DialRatio
-	if r == 0 {
-		r = defaultDialRatio
-	}
-	return srv.MaxPeers / r
-}
-
 // listenLoop runs in its own goroutine and accepts
 // inbound connections.
 func (srv *Server) listenLoop() {
@@ -935,18 +892,20 @@ func (srv *Server) listenLoop() {
 }
 
 func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error {
-	if remoteIP != nil {
-		// Reject connections that do not match NetRestrict.
-		if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
-			return fmt.Errorf("not whitelisted in NetRestrict")
-		}
-		// Reject Internet peers that try too often.
-		srv.inboundHistory.expire(time.Now())
-		if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
-			return fmt.Errorf("too many attempts")
-		}
-		srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime))
+	if remoteIP == nil {
+		return nil
+	}
+	// Reject connections that do not match NetRestrict.
+	if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
+		return fmt.Errorf("not whitelisted in NetRestrict")
 	}
+	// Reject Internet peers that try too often.
+	now := srv.clock.Now()
+	srv.inboundHistory.expire(now, nil)
+	if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
+		return fmt.Errorf("too many attempts")
+	}
+	srv.inboundHistory.add(remoteIP.String(), now.Add(inboundThrottleTime))
 	return nil
 }
 
@@ -958,7 +917,6 @@ func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node)
 	err := srv.setupConn(c, flags, dialDest)
 	if err != nil {
 		c.close(err)
-		srv.log.Trace("Setting up connection failed", "addr", fd.RemoteAddr(), "err", err)
 	}
 	return err
 }
@@ -977,7 +935,9 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
 	if dialDest != nil {
 		dialPubkey = new(ecdsa.PublicKey)
 		if err := dialDest.Load((*enode.Secp256k1)(dialPubkey)); err != nil {
-			return errors.New("dial destination doesn't have a secp256k1 public key")
+			err = errors.New("dial destination doesn't have a secp256k1 public key")
+			srv.log.Trace("Setting up connection failed", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
+			return err
 		}
 	}
 
@@ -1006,7 +966,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
 	// Run the capability negotiation handshake.
 	phs, err := c.doProtoHandshake(srv.ourHandshake)
 	if err != nil {
-		clog.Trace("Failed proto handshake", "err", err)
+		clog.Trace("Failed p2p handshake", "err", err)
 		return err
 	}
 	if id := c.node.ID(); !bytes.Equal(crypto.Keccak256(phs.ID), id[:]) {
@@ -1020,9 +980,6 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro
 		return err
 	}
 
-	// If the checks completed successfully, the connection has been added as a peer and
-	// runPeer has been launched.
-	clog.Trace("Connection set up", "inbound", dialDest == nil)
 	return nil
 }
 
@@ -1054,15 +1011,22 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
 	return <-c.cont
 }
 
+func (srv *Server) launchPeer(c *conn) *Peer {
+	p := newPeer(srv.log, c, srv.Protocols)
+	if srv.EnableMsgEvents {
+		// If message events are enabled, pass the peerFeed
+		// to the peer.
+		p.events = &srv.peerFeed
+	}
+	go srv.runPeer(p)
+	return p
+}
+
 // runPeer runs in its own goroutine for each peer.
-// it waits until the Peer logic returns and removes
-// the peer.
 func (srv *Server) runPeer(p *Peer) {
 	if srv.newPeerHook != nil {
 		srv.newPeerHook(p)
 	}
-
-	// broadcast peer add
 	srv.peerFeed.Send(&PeerEvent{
 		Type:          PeerEventTypeAdd,
 		Peer:          p.ID(),
@@ -1070,10 +1034,18 @@ func (srv *Server) runPeer(p *Peer) {
 		LocalAddress:  p.LocalAddr().String(),
 	})
 
-	// run the protocol
+	// Run the per-peer main loop.
 	remoteRequested, err := p.run()
 
-	// broadcast peer drop
+	// Announce disconnect on the main loop to update the peer set.
+	// The main loop waits for existing peers to be sent on srv.delpeer
+	// before returning, so this send should not select on srv.quit.
+	srv.delpeer <- peerDrop{p, err, remoteRequested}
+
+	// Broadcast peer drop to external subscribers. This needs to be
+	// after the send to delpeer so subscribers have a consistent view of
+	// the peer set (i.e. Server.Peers() doesn't include the peer when the
+	// event is received).
 	srv.peerFeed.Send(&PeerEvent{
 		Type:          PeerEventTypeDrop,
 		Peer:          p.ID(),
@@ -1081,10 +1053,6 @@ func (srv *Server) runPeer(p *Peer) {
 		RemoteAddress: p.RemoteAddr().String(),
 		LocalAddress:  p.LocalAddr().String(),
 	})
-
-	// Note: run waits for existing peers to be sent on srv.delpeer
-	// before returning, so this send should not select on srv.quit.
-	srv.delpeer <- peerDrop{p, err, remoteRequested}
 }
 
 // NodeInfo represents a short summary of the information known about the host.
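
From a caller's perspective, the new static-peer semantics look roughly like the sketch
below. It reuses the newNode/uintID test helpers from p2p/peer_test.go; the function name
and address are illustrative, and the key point is that RemovePeer now blocks:

	func demoStaticPeer(srv *Server) {
		node := newNode(uintID(9), "10.0.0.9:30303")
		srv.AddPeer(node)    // add to the static set; dialed once a dial slot is free
		srv.RemovePeer(node) // drop the static entry and wait until the peer is gone
	}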

+ 59 - 159
p2p/server_test.go

@@ -34,10 +34,6 @@ import (
 	"golang.org/x/crypto/sha3"
 )
 
-// func init() {
-// 	log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
-// }
-
 type testTransport struct {
 	rpub *ecdsa.PublicKey
 	*rlpx
@@ -72,11 +68,12 @@ func (c *testTransport) close(err error) {
 
 func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server {
 	config := Config{
-		Name:       "test",
-		MaxPeers:   10,
-		ListenAddr: "127.0.0.1:0",
-		PrivateKey: newkey(),
-		Logger:     testlog.Logger(t, log.LvlTrace),
+		Name:        "test",
+		MaxPeers:    10,
+		ListenAddr:  "127.0.0.1:0",
+		NoDiscovery: true,
+		PrivateKey:  newkey(),
+		Logger:      testlog.Logger(t, log.LvlTrace),
 	}
 	server := &Server{
 		Config:       config,
@@ -131,11 +128,10 @@ func TestServerDial(t *testing.T) {
 		t.Fatalf("could not setup listener: %v", err)
 	}
 	defer listener.Close()
-	accepted := make(chan net.Conn)
+	accepted := make(chan net.Conn, 1)
 	go func() {
 		conn, err := listener.Accept()
 		if err != nil {
-			t.Error("accept error:", err)
 			return
 		}
 		accepted <- conn
@@ -205,155 +201,38 @@ func TestServerDial(t *testing.T) {
 	}
 }
 
-// This test checks that tasks generated by dialstate are
-// actually executed and taskdone is called for them.
-func TestServerTaskScheduling(t *testing.T) {
-	var (
-		done           = make(chan *testTask)
-		quit, returned = make(chan struct{}), make(chan struct{})
-		tc             = 0
-		tg             = taskgen{
-			newFunc: func(running int, peers map[enode.ID]*Peer) []task {
-				tc++
-				return []task{&testTask{index: tc - 1}}
-			},
-			doneFunc: func(t task) {
-				select {
-				case done <- t.(*testTask):
-				case <-quit:
-				}
-			},
-		}
-	)
+// This test checks that RemovePeer disconnects the peer if it is connected.
+func TestServerRemovePeerDisconnect(t *testing.T) {
+	srv1 := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "1"),
+	}}
+	srv2 := &Server{Config: Config{
+		PrivateKey:  newkey(),
+		MaxPeers:    1,
+		NoDiscovery: true,
+		NoDial:      true,
+		ListenAddr:  "127.0.0.1:0",
+		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "2"),
+	}}
+	srv1.Start()
+	defer srv1.Stop()
+	srv2.Start()
+	defer srv2.Stop()
 
-	// The Server in this test isn't actually running
-	// because we're only interested in what run does.
-	db, _ := enode.OpenDB("")
-	srv := &Server{
-		Config:    Config{MaxPeers: 10},
-		localnode: enode.NewLocalNode(db, newkey()),
-		nodedb:    db,
-		discmix:   enode.NewFairMix(0),
-		quit:      make(chan struct{}),
-		running:   true,
-		log:       log.New(),
-	}
-	srv.loopWG.Add(1)
-	go func() {
-		srv.run(tg)
-		close(returned)
-	}()
-
-	var gotdone []*testTask
-	for i := 0; i < 100; i++ {
-		gotdone = append(gotdone, <-done)
+	if !syncAddPeer(srv1, srv2.Self()) {
+		t.Fatal("peer not connected")
 	}
-	for i, task := range gotdone {
-		if task.index != i {
-			t.Errorf("task %d has wrong index, got %d", i, task.index)
-			break
-		}
-		if !task.called {
-			t.Errorf("task %d was not called", i)
-			break
-		}
-	}
-
-	close(quit)
-	srv.Stop()
-	select {
-	case <-returned:
-	case <-time.After(500 * time.Millisecond):
-		t.Error("Server.run did not return within 500ms")
+	srv1.RemovePeer(srv2.Self())
+	if srv1.PeerCount() > 0 {
+		t.Fatal("removed peer still connected")
 	}
 }
 
-// This test checks that Server doesn't drop tasks,
-// even if newTasks returns more than the maximum number of tasks.
-func TestServerManyTasks(t *testing.T) {
-	alltasks := make([]task, 300)
-	for i := range alltasks {
-		alltasks[i] = &testTask{index: i}
-	}
-
-	var (
-		db, _ = enode.OpenDB("")
-		srv   = &Server{
-			quit:      make(chan struct{}),
-			localnode: enode.NewLocalNode(db, newkey()),
-			nodedb:    db,
-			running:   true,
-			log:       log.New(),
-			discmix:   enode.NewFairMix(0),
-		}
-		done       = make(chan *testTask)
-		start, end = 0, 0
-	)
-	defer srv.Stop()
-	srv.loopWG.Add(1)
-	go srv.run(taskgen{
-		newFunc: func(running int, peers map[enode.ID]*Peer) []task {
-			start, end = end, end+maxActiveDialTasks+10
-			if end > len(alltasks) {
-				end = len(alltasks)
-			}
-			return alltasks[start:end]
-		},
-		doneFunc: func(tt task) {
-			done <- tt.(*testTask)
-		},
-	})
-
-	doneset := make(map[int]bool)
-	timeout := time.After(2 * time.Second)
-	for len(doneset) < len(alltasks) {
-		select {
-		case tt := <-done:
-			if doneset[tt.index] {
-				t.Errorf("task %d got done more than once", tt.index)
-			} else {
-				doneset[tt.index] = true
-			}
-		case <-timeout:
-			t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
-			for i := 0; i < len(alltasks); i++ {
-				if !doneset[i] {
-					t.Logf("task %d not done", i)
-				}
-			}
-			return
-		}
-	}
-}
-
-type taskgen struct {
-	newFunc  func(running int, peers map[enode.ID]*Peer) []task
-	doneFunc func(task)
-}
-
-func (tg taskgen) newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task {
-	return tg.newFunc(running, peers)
-}
-func (tg taskgen) taskDone(t task, now time.Time) {
-	tg.doneFunc(t)
-}
-func (tg taskgen) addStatic(*enode.Node) {
-}
-func (tg taskgen) removeStatic(*enode.Node) {
-}
-
-type testTask struct {
-	index  int
-	called bool
-}
-
-func (t *testTask) Do(srv *Server) {
-	t.called = true
-}
-
-// This test checks that connections are disconnected
-// just after the encryption handshake when the server is
-// at capacity. Trusted connections should still be accepted.
+// This test checks that connections are disconnected just after the encryption handshake
+// when the server is at capacity. Trusted connections should still be accepted.
 func TestServerAtCap(t *testing.T) {
 	trustedNode := newkey()
 	trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey)
@@ -363,7 +242,8 @@ func TestServerAtCap(t *testing.T) {
 			MaxPeers:     10,
 			NoDial:       true,
 			NoDiscovery:  true,
-			TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
+			TrustedNodes: []*enode.Node{newNode(trustedID, "")},
+			Logger:       testlog.Logger(t, log.LvlTrace),
 		},
 	}
 	if err := srv.Start(); err != nil {
@@ -401,14 +281,14 @@ func TestServerAtCap(t *testing.T) {
 	}
 
 	// Remove from trusted set and try again
-	srv.RemoveTrustedPeer(newNode(trustedID, nil))
+	srv.RemoveTrustedPeer(newNode(trustedID, ""))
 	c = newconn(trustedID)
 	if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
 		t.Error("wrong error for insert:", err)
 	}
 
 	// Add anotherID to trusted set and try again
-	srv.AddTrustedPeer(newNode(anotherID, nil))
+	srv.AddTrustedPeer(newNode(anotherID, ""))
 	c = newconn(anotherID)
 	if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
 		t.Error("unexpected error for trusted conn @posthandshake:", err)
@@ -439,9 +319,9 @@ func TestServerPeerLimits(t *testing.T) {
 			NoDial:      true,
 			NoDiscovery: true,
 			Protocols:   []Protocol{discard},
+			Logger:      testlog.Logger(t, log.LvlTrace),
 		},
 		newTransport: func(fd net.Conn) transport { return tp },
-		log:          log.New(),
 	}
 	if err := srv.Start(); err != nil {
 		t.Fatalf("couldn't start server: %v", err)
@@ -724,3 +604,23 @@ func (l *fakeAddrListener) Accept() (net.Conn, error) {
 func (c *fakeAddrConn) RemoteAddr() net.Addr {
 	return c.remoteAddr
 }
+
+func syncAddPeer(srv *Server, node *enode.Node) bool {
+	var (
+		ch      = make(chan *PeerEvent)
+		sub     = srv.SubscribeEvents(ch)
+		timeout = time.After(2 * time.Second)
+	)
+	defer sub.Unsubscribe()
+	srv.AddPeer(node)
+	for {
+		select {
+		case ev := <-ch:
+			if ev.Type == PeerEventTypeAdd && ev.Peer == node.ID() {
+				return true
+			}
+		case <-timeout:
+			return false
+		}
+	}
+}
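
TestServerRemovePeerDisconnect above relies on the maxDialedConns fix: srv1 runs with
MaxPeers == 1 and must still get a dial slot. A worked sketch of the arithmetic, mirroring
maxDialedConns in server.go (defaultDialRatio is 3):

	func dialSlots(maxPeers, dialRatio int, noDial bool) int {
		if noDial || maxPeers == 0 {
			return 0 // dialing disabled
		}
		if dialRatio == 0 {
			dialRatio = 3 // defaultDialRatio
		}
		limit := maxPeers / dialRatio // MaxPeers == 1 gives 1/3 == 0
		if limit == 0 {
			limit = 1 // always keep at least one dial slot
		}
		return limit // so srv1's single slot is dedicated to dialing srv2
	}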

+ 2 - 1
p2p/simulations/adapters/inproc.go

@@ -17,6 +17,7 @@
 package adapters
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"math"
@@ -126,7 +127,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) {
 
 // Dial implements the p2p.NodeDialer interface by connecting to the node using
 // an in-memory net.Pipe
-func (s *SimAdapter) Dial(dest *enode.Node) (conn net.Conn, err error) {
+func (s *SimAdapter) Dial(ctx context.Context, dest *enode.Node) (conn net.Conn, err error) {
 	node, ok := s.GetNode(dest.ID())
 	if !ok {
 		return nil, fmt.Errorf("unknown node: %s", dest.ID())

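Besides SimAdapter, any custom NodeDialer has to adopt the new context-aware signature.
A rough sketch of a TCP dialer that honors the context; it mirrors what the built-in
tcpDialer does, but the type name here is illustrative and the imports (context, net,
enode) are assumed:

	type ctxTCPDialer struct {
		d net.Dialer
	}

	// Dial implements p2p.NodeDialer with the context-aware signature.
	func (t ctxTCPDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
		addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()}
		return t.d.DialContext(ctx, "tcp", addr.String())
	}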
+ 12 - 8
p2p/util.go

@@ -18,7 +18,8 @@ package p2p
 
 import (
 	"container/heap"
-	"time"
+
+	"github.com/ethereum/go-ethereum/common/mclock"
 )
 
 // expHeap tracks strings and their expiry time.
@@ -27,16 +28,16 @@ type expHeap []expItem
 // expItem is an entry in addrHistory.
 type expItem struct {
 	item string
-	exp  time.Time
+	exp  mclock.AbsTime
 }
 
 // nextExpiry returns the next expiry time.
-func (h *expHeap) nextExpiry() time.Time {
+func (h *expHeap) nextExpiry() mclock.AbsTime {
 	return (*h)[0].exp
 }
 
 // add adds an item and sets its expiry time.
-func (h *expHeap) add(item string, exp time.Time) {
+func (h *expHeap) add(item string, exp mclock.AbsTime) {
 	heap.Push(h, expItem{item, exp})
 }
 
@@ -51,15 +52,18 @@ func (h expHeap) contains(item string) bool {
 }
 
 // expire removes items with expiry time before 'now'.
-func (h *expHeap) expire(now time.Time) {
-	for h.Len() > 0 && h.nextExpiry().Before(now) {
-		heap.Pop(h)
+func (h *expHeap) expire(now mclock.AbsTime, onExp func(string)) {
+	for h.Len() > 0 && h.nextExpiry() < now {
+		item := heap.Pop(h)
+		if onExp != nil {
+			onExp(item.(expItem).item)
+		}
 	}
 }
 
 // heap.Interface boilerplate
 func (h expHeap) Len() int            { return len(h) }
-func (h expHeap) Less(i, j int) bool  { return h[i].exp.Before(h[j].exp) }
+func (h expHeap) Less(i, j int) bool  { return h[i].exp < h[j].exp }
 func (h expHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
 func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) }
 func (h *expHeap) Pop() interface{} {

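The expiry heap now runs on mclock.AbsTime and expire takes an optional callback. A small
sketch of how it might be driven with the simulated clock (the IP string is illustrative,
and the mclock/time imports are assumed):

	func exampleExpHeap() {
		var (
			h     expHeap
			clock = new(mclock.Simulated)
		)
		h.add("198.51.100.7", clock.Now().Add(30*time.Second))
		clock.Run(time.Minute) // advance the simulated clock past the expiry
		h.expire(clock.Now(), func(item string) {
			// optional callback; pass nil when no notification is needed
		})
	}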
+ 7 - 5
p2p/util_test.go

@@ -19,30 +19,32 @@ package p2p
 import (
 	"testing"
 	"time"
+
+	"github.com/ethereum/go-ethereum/common/mclock"
 )
 
 func TestExpHeap(t *testing.T) {
 	var h expHeap
 
 	var (
-		basetime = time.Unix(4000, 0)
+		basetime = mclock.AbsTime(10)
 		exptimeA = basetime.Add(2 * time.Second)
 		exptimeB = basetime.Add(3 * time.Second)
 		exptimeC = basetime.Add(4 * time.Second)
 	)
-	h.add("a", exptimeA)
 	h.add("b", exptimeB)
+	h.add("a", exptimeA)
 	h.add("c", exptimeC)
 
-	if !h.nextExpiry().Equal(exptimeA) {
+	if h.nextExpiry() != exptimeA {
 		t.Fatal("wrong nextExpiry")
 	}
 	if !h.contains("a") || !h.contains("b") || !h.contains("c") {
 		t.Fatal("heap doesn't contain all live items")
 	}
 
-	h.expire(exptimeA.Add(1))
-	if !h.nextExpiry().Equal(exptimeB) {
+	h.expire(exptimeA.Add(1), nil)
+	if h.nextExpiry() != exptimeB {
 		t.Fatal("wrong nextExpiry")
 	}
 	if h.contains("a") {