Explorar el Código

p2p, p2p/discover: track bootstrap state in p2p/discover

This change simplifies the dial scheduling logic because it
no longer needs to track whether the discovery table has been
bootstrapped.
Felix Lange hace 10 años
padre
commit
04c6369a09
Se han modificado 6 ficheros con 110 adiciones y 91 borrados
  1. 7 26
      p2p/dial.go
  2. 8 20
      p2p/dial_test.go
  3. 19 1
      p2p/discover/node.go
  4. 68 37
      p2p/discover/table.go
  5. 5 7
      p2p/discover/udp.go
  6. 3 0
      p2p/server.go

+ 7 - 26
p2p/dial.go

@@ -46,7 +46,6 @@ type dialstate struct {
 	ntab        discoverTable
 
 	lookupRunning bool
-	bootstrapped  bool
 
 	dialing     map[discover.NodeID]connFlag
 	lookupBuf   []*discover.Node // current discovery lookup results
@@ -58,7 +57,6 @@ type dialstate struct {
 type discoverTable interface {
 	Self() *discover.Node
 	Close()
-	Bootstrap([]*discover.Node)
 	Lookup(target discover.NodeID) []*discover.Node
 	ReadRandomNodes([]*discover.Node) int
 }
@@ -84,13 +82,9 @@ type dialTask struct {
 
 // discoverTask runs discovery table operations.
 // Only one discoverTask is active at any time.
-//
-// If bootstrap is true, the task runs Table.Bootstrap,
-// otherwise it performs a random lookup and leaves the
-// results in the task.
+// discoverTask.Do performs a random lookup.
 type discoverTask struct {
-	bootstrap bool
-	results   []*discover.Node
+	results []*discover.Node
 }
 
 // A waitExpireTask is generated if there are no other tasks
@@ -154,7 +148,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
 	// Use random nodes from the table for half of the necessary
 	// dynamic dials.
 	randomCandidates := needDynDials / 2
-	if randomCandidates > 0 && s.bootstrapped {
+	if randomCandidates > 0 {
 		n := s.ntab.ReadRandomNodes(s.randomNodes)
 		for i := 0; i < randomCandidates && i < n; i++ {
 			if addDial(dynDialedConn, s.randomNodes[i]) {
@@ -171,12 +165,10 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
 		}
 	}
 	s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
-	// Launch a discovery lookup if more candidates are needed. The
-	// first discoverTask bootstraps the table and won't return any
-	// results.
+	// Launch a discovery lookup if more candidates are needed.
 	if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
 		s.lookupRunning = true
-		newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped})
+		newtasks = append(newtasks, &discoverTask{})
 	}
 
 	// Launch a timer to wait for the next node to expire if all
@@ -196,9 +188,6 @@ func (s *dialstate) taskDone(t task, now time.Time) {
 		s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
 		delete(s.dialing, t.dest.ID)
 	case *discoverTask:
-		if t.bootstrap {
-			s.bootstrapped = true
-		}
 		s.lookupRunning = false
 		s.lookupBuf = append(s.lookupBuf, t.results...)
 	}
@@ -221,10 +210,6 @@ func (t *dialTask) String() string {
 }
 
 func (t *discoverTask) Do(srv *Server) {
-	if t.bootstrap {
-		srv.ntab.Bootstrap(srv.BootstrapNodes)
-		return
-	}
 	// newTasks generates a lookup task whenever dynamic dials are
 	// necessary. Lookups need to take some time, otherwise the
 	// event loop spins too fast.
@@ -238,12 +223,8 @@ func (t *discoverTask) Do(srv *Server) {
 	t.results = srv.ntab.Lookup(target)
 }
 
-func (t *discoverTask) String() (s string) {
-	if t.bootstrap {
-		s = "discovery bootstrap"
-	} else {
-		s = "discovery lookup"
-	}
+func (t *discoverTask) String() string {
+	s := "discovery lookup"
 	if len(t.results) > 0 {
 		s += fmt.Sprintf(" (%d results)", len(t.results))
 	}

+ 8 - 20
p2p/dial_test.go

@@ -76,15 +76,10 @@ func runDialTest(t *testing.T, test dialtest) {
 
 type fakeTable []*discover.Node
 
-func (t fakeTable) Self() *discover.Node       { return new(discover.Node) }
-func (t fakeTable) Close()                     {}
-func (t fakeTable) Bootstrap([]*discover.Node) {}
-func (t fakeTable) Lookup(target discover.NodeID) []*discover.Node {
-	return nil
-}
-func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int {
-	return copy(buf, t)
-}
+func (t fakeTable) Self() *discover.Node                     { return new(discover.Node) }
+func (t fakeTable) Close()                                   {}
+func (t fakeTable) Lookup(discover.NodeID) []*discover.Node  { return nil }
+func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, t) }
 
 // This test checks that dynamic dials are launched from discovery results.
 func TestDialStateDynDial(t *testing.T) {
@@ -98,7 +93,7 @@ func TestDialStateDynDial(t *testing.T) {
 					{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
 					{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
 				},
-				new: []task{&discoverTask{bootstrap: true}},
+				new: []task{&discoverTask{}},
 			},
 			// Dynamic dials are launched when it completes.
 			{
@@ -108,7 +103,7 @@ func TestDialStateDynDial(t *testing.T) {
 					{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
 				},
 				done: []task{
-					&discoverTask{bootstrap: true, results: []*discover.Node{
+					&discoverTask{results: []*discover.Node{
 						{ID: uintID(2)}, // this one is already connected and not dialed.
 						{ID: uintID(3)},
 						{ID: uintID(4)},
@@ -238,22 +233,15 @@ func TestDialStateDynDialFromTable(t *testing.T) {
 	runDialTest(t, dialtest{
 		init: newDialState(nil, table, 10),
 		rounds: []round{
-			// Discovery bootstrap is launched.
-			{
-				new: []task{&discoverTask{bootstrap: true}},
-			},
 			// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
 			{
-				done: []task{
-					&discoverTask{bootstrap: true},
-				},
 				new: []task{
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
-					&discoverTask{bootstrap: false},
+					&discoverTask{},
 				},
 			},
 			// Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
@@ -275,7 +263,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
 					&dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
-					&discoverTask{bootstrap: false},
+					&discoverTask{},
 				},
 			},
 			// Dialing nodes 3,4,5 fails. The dials from the lookup succeed.

+ 19 - 1
p2p/discover/node.go

@@ -80,6 +80,24 @@ func (n *Node) Incomplete() bool {
 	return n.IP == nil
 }
 
+// checks whether n is a valid complete node.
+func (n *Node) validateComplete() error {
+	if n.Incomplete() {
+		return errors.New("incomplete node")
+	}
+	if n.UDP == 0 {
+		return errors.New("missing UDP port")
+	}
+	if n.TCP == 0 {
+		return errors.New("missing TCP port")
+	}
+	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
+		return errors.New("invalid IP (multicast/unspecified)")
+	}
+	_, err := n.ID.Pubkey() // validate the key (on curve, etc.)
+	return err
+}
+
 // The string representation of a Node is a URL.
 // Please see ParseNode for a description of the format.
 func (n *Node) String() string {
@@ -249,7 +267,7 @@ func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
 	p.X.SetBytes(id[:half])
 	p.Y.SetBytes(id[half:])
 	if !p.Curve.IsOnCurve(p.X, p.Y) {
-		return nil, errors.New("not a point on the S256 curve")
+		return nil, errors.New("id is invalid secp256k1 curve point")
 	}
 	return p, nil
 }

+ 68 - 37
p2p/discover/table.go

@@ -25,6 +25,7 @@ package discover
 import (
 	"crypto/rand"
 	"encoding/binary"
+	"fmt"
 	"net"
 	"sort"
 	"sync"
@@ -56,7 +57,7 @@ type Table struct {
 	nursery []*Node           // bootstrap nodes
 	db      *nodeDB           // database of known nodes
 
-	refreshReq chan struct{}
+	refreshReq chan chan struct{}
 	closeReq   chan struct{}
 	closed     chan struct{}
 
@@ -102,7 +103,7 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
 		self:       NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)),
 		bonding:    make(map[NodeID]*bondproc),
 		bondslots:  make(chan struct{}, maxBondingPingPongs),
-		refreshReq: make(chan struct{}),
+		refreshReq: make(chan chan struct{}),
 		closeReq:   make(chan struct{}),
 		closed:     make(chan struct{}),
 	}
@@ -179,21 +180,27 @@ func (tab *Table) Close() {
 	}
 }
 
-// Bootstrap sets the bootstrap nodes. These nodes are used to connect
-// to the network if the table is empty. Bootstrap will also attempt to
-// fill the table by performing random lookup operations on the
-// network.
-func (tab *Table) Bootstrap(nodes []*Node) {
+// SetFallbackNodes sets the initial points of contact. These nodes
+// are used to connect to the network if the table is empty and there
+// are no known nodes in the database.
+func (tab *Table) SetFallbackNodes(nodes []*Node) error {
+	for _, n := range nodes {
+		if err := n.validateComplete(); err != nil {
+			return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
+		}
+	}
 	tab.mutex.Lock()
-	// TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
 	tab.nursery = make([]*Node, 0, len(nodes))
 	for _, n := range nodes {
 		cpy := *n
+		// Recompute cpy.sha because the node might not have been
+		// created by NewNode or ParseNode.
 		cpy.sha = crypto.Sha3Hash(n.ID[:])
 		tab.nursery = append(tab.nursery, &cpy)
 	}
 	tab.mutex.Unlock()
-	tab.requestRefresh()
+	tab.refresh()
+	return nil
 }
 
 // Resolve searches for a specific node with the given ID.
@@ -224,26 +231,36 @@ func (tab *Table) Resolve(targetID NodeID) *Node {
 // The given target does not need to be an actual node
 // identifier.
 func (tab *Table) Lookup(targetID NodeID) []*Node {
+	return tab.lookup(targetID, true)
+}
+
+func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
 	var (
 		target         = crypto.Sha3Hash(targetID[:])
 		asked          = make(map[NodeID]bool)
 		seen           = make(map[NodeID]bool)
 		reply          = make(chan []*Node, alpha)
 		pendingQueries = 0
+		result         *nodesByDistance
 	)
 	// don't query further if we hit ourself.
 	// unlikely to happen often in practice.
 	asked[tab.self.ID] = true
 
-	tab.mutex.Lock()
-	// generate initial result set
-	result := tab.closest(target, bucketSize)
-	tab.mutex.Unlock()
-
-	// If the result set is empty, all nodes were dropped, refresh.
-	if len(result.entries) == 0 {
-		tab.requestRefresh()
-		return nil
+	for {
+		tab.mutex.Lock()
+		// generate initial result set
+		result = tab.closest(target, bucketSize)
+		tab.mutex.Unlock()
+		if len(result.entries) > 0 || !refreshIfEmpty {
+			break
+		}
+		// The result set is empty, all nodes were dropped, refresh.
+		// We actually wait for the refresh to complete here. The very
+		// first query will hit this case and run the bootstrapping
+		// logic.
+		<-tab.refresh()
+		refreshIfEmpty = false
 	}
 
 	for {
@@ -287,24 +304,24 @@ func (tab *Table) Lookup(targetID NodeID) []*Node {
 	return result.entries
 }
 
-func (tab *Table) requestRefresh() {
+func (tab *Table) refresh() <-chan struct{} {
+	done := make(chan struct{})
 	select {
-	case tab.refreshReq <- struct{}{}:
+	case tab.refreshReq <- done:
 	case <-tab.closed:
+		close(done)
 	}
+	return done
 }
 
+// refreshLoop schedules doRefresh runs and coordinates shutdown.
 func (tab *Table) refreshLoop() {
-	defer func() {
-		tab.db.close()
-		if tab.net != nil {
-			tab.net.close()
-		}
-		close(tab.closed)
-	}()
-
-	timer := time.NewTicker(autoRefreshInterval)
-	var done chan struct{}
+	var (
+		timer   = time.NewTicker(autoRefreshInterval)
+		waiting []chan struct{} // accumulates waiting callers while doRefresh runs
+		done    chan struct{}   // where doRefresh reports completion
+	)
+loop:
 	for {
 		select {
 		case <-timer.C:
@@ -312,20 +329,34 @@ func (tab *Table) refreshLoop() {
 				done = make(chan struct{})
 				go tab.doRefresh(done)
 			}
-		case <-tab.refreshReq:
+		case req := <-tab.refreshReq:
+			waiting = append(waiting, req)
 			if done == nil {
 				done = make(chan struct{})
 				go tab.doRefresh(done)
 			}
 		case <-done:
+			for _, ch := range waiting {
+				close(ch)
+			}
+			waiting = nil
 			done = nil
 		case <-tab.closeReq:
-			if done != nil {
-				<-done
-			}
-			return
+			break loop
 		}
 	}
+
+	if tab.net != nil {
+		tab.net.close()
+	}
+	if done != nil {
+		<-done
+	}
+	for _, ch := range waiting {
+		close(ch)
+	}
+	tab.db.close()
+	close(tab.closed)
 }
 
 // doRefresh performs a lookup for a random target to keep buckets
@@ -342,7 +373,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
 	// We perform a lookup with a random target instead.
 	var target NodeID
 	rand.Read(target[:])
-	result := tab.Lookup(target)
+	result := tab.lookup(target, false)
 	if len(result) > 0 {
 		return
 	}
@@ -366,7 +397,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
 	tab.mutex.Unlock()
 
 	// Finally, do a self lookup to fill up the buckets.
-	tab.Lookup(tab.self.ID)
+	tab.lookup(tab.self.ID, false)
 }
 
 // closest returns the n nodes in the table that are closest to the

+ 5 - 7
p2p/discover/udp.go

@@ -114,13 +114,11 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
 }
 
-func nodeFromRPC(rn rpcNode) (n *Node, valid bool) {
+func nodeFromRPC(rn rpcNode) (*Node, error) {
 	// TODO: don't accept localhost, LAN addresses from internet hosts
-	// TODO: check public key is on secp256k1 curve
-	if rn.IP.IsMulticast() || rn.IP.IsUnspecified() || rn.UDP == 0 {
-		return nil, false
-	}
-	return NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP), true
+	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
+	err := n.validateComplete()
+	return n, err
 }
 
 func nodeToRPC(n *Node) rpcNode {
@@ -271,7 +269,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
 		reply := r.(*neighbors)
 		for _, rn := range reply.Nodes {
 			nreceived++
-			if n, valid := nodeFromRPC(rn); valid {
+			if n, err := nodeFromRPC(rn); err == nil {
 				nodes = append(nodes, n)
 			}
 		}

+ 3 - 0
p2p/server.go

@@ -334,6 +334,9 @@ func (srv *Server) Start() (err error) {
 		if err != nil {
 			return err
 		}
+		if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil {
+			return err
+		}
 		srv.ntab = ntab
 	}