|
|
@@ -20,6 +20,7 @@ import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"sync"
|
|
|
"time"
|
|
|
@@ -705,8 +706,11 @@ func (net *Network) snapshot(addServices []string, removeServices []string) (*Sn
|
|
|
return snap, nil
|
|
|
}
|
|
|
|
|
|
+var snapshotLoadTimeout = 120 * time.Second
|
|
|
+
|
|
|
// Load loads a network snapshot
|
|
|
func (net *Network) Load(snap *Snapshot) error {
|
|
|
+ // Start nodes.
|
|
|
for _, n := range snap.Nodes {
|
|
|
if _, err := net.NewNodeWithConfig(n.Node.Config); err != nil {
|
|
|
return err
|
|
|
@@ -718,6 +722,69 @@ func (net *Network) Load(snap *Snapshot) error {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ // Prepare connection events counter.
|
|
|
+ allConnected := make(chan struct{}) // closed when all connections are established
|
|
|
+ done := make(chan struct{}) // ensures that the event loop goroutine is terminated
|
|
|
+ defer close(done)
|
|
|
+
|
|
|
+ // Subscribe to event channel.
|
|
|
+ // It needs to be done outside of the event loop goroutine (created below)
|
|
|
+ // to ensure that the event channel is blocking before connect calls are made.
|
|
|
+ events := make(chan *Event)
|
|
|
+ sub := net.Events().Subscribe(events)
|
|
|
+ defer sub.Unsubscribe()
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ // Expected number of connections.
|
|
|
+ total := len(snap.Conns)
|
|
|
+ // Set of all established connections from the snapshot, not other connections.
|
|
|
+ // Key array element 0 is the connection One field value, and element 1 connection Other field.
|
|
|
+ connections := make(map[[2]enode.ID]struct{}, total)
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case e := <-events:
|
|
|
+ // Ignore control events as they do not represent
|
|
|
+ // connect or disconnect (Up) state change.
|
|
|
+ if e.Control {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ // Detect only connection events.
|
|
|
+ if e.Type != EventTypeConn {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ connection := [2]enode.ID{e.Conn.One, e.Conn.Other}
|
|
|
+ // Nodes are still not connected or have been disconnected.
|
|
|
+ if !e.Conn.Up {
|
|
|
+ // Delete the connection from the set of established connections.
|
|
|
+ // This will prevent false positive in case disconnections happen.
|
|
|
+ delete(connections, connection)
|
|
|
+ log.Warn("load snapshot: unexpected disconnection", "one", e.Conn.One, "other", e.Conn.Other)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ // Check that the connection is from the snapshot.
|
|
|
+ for _, conn := range snap.Conns {
|
|
|
+ if conn.One == e.Conn.One && conn.Other == e.Conn.Other {
|
|
|
+ // Add the connection to the set of established connections.
|
|
|
+ connections[connection] = struct{}{}
|
|
|
+ if len(connections) == total {
|
|
|
+ // Signal that all nodes are connected.
|
|
|
+ close(allConnected)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case <-done:
|
|
|
+ // Load function returned, terminate this goroutine.
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Start connecting.
|
|
|
for _, conn := range snap.Conns {
|
|
|
|
|
|
if !net.GetNode(conn.One).Up || !net.GetNode(conn.Other).Up {
|
|
|
@@ -729,6 +796,14 @@ func (net *Network) Load(snap *Snapshot) error {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ select {
|
|
|
+ // Wait until all connections from the snapshot are established.
|
|
|
+ case <-allConnected:
|
|
|
+ // Make sure that we do not wait forever.
|
|
|
+ case <-time.After(snapshotLoadTimeout):
|
|
|
+ return errors.New("snapshot connections not established")
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|