Bläddra i källkod

p2p, p2p/discover, p2p/discv5: add IP network restriction feature

The p2p packages can now be configured to restrict all communication to
a certain subset of IP networks. This feature is meant to be used for
private networks.
Felix Lange 9 år sedan
förälder
incheckning
a47341cf96
9 ändrade filer med 124 tillägg och 34 borttagningar
  1. 38 7
      p2p/dial.go
  2. 36 5
      p2p/dial_test.go
  3. 15 10
      p2p/discover/udp.go
  4. 1 1
      p2p/discover/udp_test.go
  5. 9 3
      p2p/discv5/net.go
  6. 1 1
      p2p/discv5/net_test.go
  7. 1 1
      p2p/discv5/sim_test.go
  8. 2 2
      p2p/discv5/udp.go
  9. 21 4
      p2p/server.go

+ 38 - 7
p2p/dial.go

@@ -19,6 +19,7 @@ package p2p
 import (
 	"container/heap"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"net"
 	"time"
@@ -26,6 +27,7 @@ import (
 	"github.com/ethereum/go-ethereum/logger"
 	"github.com/ethereum/go-ethereum/logger/glog"
 	"github.com/ethereum/go-ethereum/p2p/discover"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
 )
 
 const (
@@ -48,6 +50,7 @@ const (
 type dialstate struct {
 	maxDynDials int
 	ntab        discoverTable
+	netrestrict *netutil.Netlist
 
 	lookupRunning bool
 	dialing       map[discover.NodeID]connFlag
@@ -100,10 +103,11 @@ type waitExpireTask struct {
 	time.Duration
 }
 
-func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
+func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
 	s := &dialstate{
 		maxDynDials: maxdyn,
 		ntab:        ntab,
+		netrestrict: netrestrict,
 		static:      make(map[discover.NodeID]*dialTask),
 		dialing:     make(map[discover.NodeID]connFlag),
 		randomNodes: make([]*discover.Node, maxdyn/2),
@@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) {
 
 func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
 	var newtasks []task
-	isDialing := func(id discover.NodeID) bool {
-		_, found := s.dialing[id]
-		return found || peers[id] != nil || s.hist.contains(id)
-	}
 	addDial := func(flag connFlag, n *discover.Node) bool {
-		if isDialing(n.ID) {
+		if err := s.checkDial(n, peers); err != nil {
+			glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err)
 			return false
 		}
 		s.dialing[n.ID] = flag
@@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
 
 	// Create dials for static nodes if they are not connected.
 	for id, t := range s.static {
-		if !isDialing(id) {
+		err := s.checkDial(t.dest, peers)
+		switch err {
+		case errNotWhitelisted, errSelf:
+			glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err)
+			delete(s.static, t.dest.ID)
+		case nil:
 			s.dialing[id] = t.flags
 			newtasks = append(newtasks, t)
 		}
@@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
 	return newtasks
 }
 
+var (
+	errSelf             = errors.New("is self")
+	errAlreadyDialing   = errors.New("already dialing")
+	errAlreadyConnected = errors.New("already connected")
+	errRecentlyDialed   = errors.New("recently dialed")
+	errNotWhitelisted   = errors.New("not contained in netrestrict whitelist")
+)
+
+func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
+	_, dialing := s.dialing[n.ID]
+	switch {
+	case dialing:
+		return errAlreadyDialing
+	case peers[n.ID] != nil:
+		return errAlreadyConnected
+	case s.ntab != nil && n.ID == s.ntab.Self().ID:
+		return errSelf
+	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
+		return errNotWhitelisted
+	case s.hist.contains(n.ID):
+		return errRecentlyDialed
+	}
+	return nil
+}
+
 func (s *dialstate) taskDone(t task, now time.Time) {
 	switch t := t.(type) {
 	case *dialTask:

+ 36 - 5
p2p/dial_test.go

@@ -25,6 +25,7 @@ import (
 
 	"github.com/davecgh/go-spew/spew"
 	"github.com/ethereum/go-ethereum/p2p/discover"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
 )
 
 func init() {
@@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf,
 // This test checks that dynamic dials are launched from discovery results.
 func TestDialStateDynDial(t *testing.T) {
 	runDialTest(t, dialtest{
-		init: newDialState(nil, fakeTable{}, 5),
+		init: newDialState(nil, fakeTable{}, 5, nil),
 		rounds: []round{
 			// A discovery query is launched.
 			{
@@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
 	}
 
 	runDialTest(t, dialtest{
-		init: newDialState(nil, table, 10),
+		init: newDialState(nil, table, 10, nil),
 		rounds: []round{
 			// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
 			{
@@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) {
 	})
 }
 
+// This test checks that candidates that do not match the netrestrict list are not dialed.
+func TestDialStateNetRestrict(t *testing.T) {
+	// This table always returns the same random nodes
+	// in the order given below.
+	table := fakeTable{
+		{ID: uintID(1), IP: net.ParseIP("127.0.0.1")},
+		{ID: uintID(2), IP: net.ParseIP("127.0.0.2")},
+		{ID: uintID(3), IP: net.ParseIP("127.0.0.3")},
+		{ID: uintID(4), IP: net.ParseIP("127.0.0.4")},
+		{ID: uintID(5), IP: net.ParseIP("127.0.2.5")},
+		{ID: uintID(6), IP: net.ParseIP("127.0.2.6")},
+		{ID: uintID(7), IP: net.ParseIP("127.0.2.7")},
+		{ID: uintID(8), IP: net.ParseIP("127.0.2.8")},
+	}
+	restrict := new(netutil.Netlist)
+	restrict.Add("127.0.2.0/24")
+
+	runDialTest(t, dialtest{
+		init: newDialState(nil, table, 10, restrict),
+		rounds: []round{
+			{
+				new: []task{
+					&dialTask{flags: dynDialedConn, dest: table[4]},
+					&discoverTask{},
+				},
+			},
+		},
+	})
+}
+
 // This test checks that static dials are launched.
 func TestDialStateStaticDial(t *testing.T) {
 	wantStatic := []*discover.Node{
@@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) {
 	}
 
 	runDialTest(t, dialtest{
-		init: newDialState(wantStatic, fakeTable{}, 0),
+		init: newDialState(wantStatic, fakeTable{}, 0, nil),
 		rounds: []round{
 			// Static dials are launched for the nodes that
 			// aren't yet connected.
@@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) {
 	}
 
 	runDialTest(t, dialtest{
-		init: newDialState(wantStatic, fakeTable{}, 0),
+		init: newDialState(wantStatic, fakeTable{}, 0, nil),
 		rounds: []round{
 			// Static dials are launched for the nodes that
 			// aren't yet connected.
@@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) {
 func TestDialResolve(t *testing.T) {
 	resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444)
 	table := &resolveMock{answer: resolved}
-	state := newDialState(nil, table, 0)
+	state := newDialState(nil, table, 0, nil)
 
 	// Check that the task is generated with an incomplete ID.
 	dest := discover.NewNode(uintID(1), nil, 0, 0)

+ 15 - 10
p2p/discover/udp.go

@@ -127,13 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
 }
 
-func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
+func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
 	if rn.UDP <= 1024 {
 		return nil, errors.New("low port")
 	}
 	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
 		return nil, err
 	}
+	if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
+		return nil, errors.New("not contained in netrestrict whitelist")
+	}
 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
 	err := n.validateComplete()
 	return n, err
@@ -157,6 +160,7 @@ type conn interface {
 // udp implements the RPC protocol.
 type udp struct {
 	conn        conn
+	netrestrict *netutil.Netlist
 	priv        *ecdsa.PrivateKey
 	ourEndpoint rpcEndpoint
 
@@ -207,7 +211,7 @@ type reply struct {
 }
 
 // ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) {
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
 	addr, err := net.ResolveUDPAddr("udp", laddr)
 	if err != nil {
 		return nil, err
@@ -216,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
 	if err != nil {
 		return nil, err
 	}
-	tab, _, err := newUDP(priv, conn, natm, nodeDBPath)
+	tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict)
 	if err != nil {
 		return nil, err
 	}
@@ -224,13 +228,14 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
 	return tab, nil
 }
 
-func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) {
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
 	udp := &udp{
-		conn:       c,
-		priv:       priv,
-		closing:    make(chan struct{}),
-		gotreply:   make(chan reply),
-		addpending: make(chan *pending),
+		conn:        c,
+		priv:        priv,
+		netrestrict: netrestrict,
+		closing:     make(chan struct{}),
+		gotreply:    make(chan reply),
+		addpending:  make(chan *pending),
 	}
 	realaddr := c.LocalAddr().(*net.UDPAddr)
 	if natm != nil {
@@ -287,7 +292,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
 		reply := r.(*neighbors)
 		for _, rn := range reply.Nodes {
 			nreceived++
-			n, err := nodeFromRPC(toaddr, rn)
+			n, err := t.nodeFromRPC(toaddr, rn)
 			if err != nil {
 				glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err)
 				continue

+ 1 - 1
p2p/discover/udp_test.go

@@ -70,7 +70,7 @@ func newUDPTest(t *testing.T) *udpTest {
 		remotekey:  newkey(),
 		remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
 	}
-	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "")
+	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil)
 	return test
 }
 

+ 9 - 3
p2p/discv5/net.go

@@ -31,6 +31,7 @@ import (
 	"github.com/ethereum/go-ethereum/logger"
 	"github.com/ethereum/go-ethereum/logger/glog"
 	"github.com/ethereum/go-ethereum/p2p/nat"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 
@@ -63,8 +64,9 @@ func debugLog(s string) {
 
 // Network manages the table and all protocol interaction.
 type Network struct {
-	db   *nodeDB // database of known nodes
-	conn transport
+	db          *nodeDB // database of known nodes
+	conn        transport
+	netrestrict *netutil.Netlist
 
 	closed           chan struct{}          // closed when loop is done
 	closeReq         chan struct{}          // 'request to close'
@@ -133,7 +135,7 @@ type timeoutEvent struct {
 	node *Node
 }
 
-func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) {
+func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
 	ourID := PubkeyID(&ourPubkey)
 
 	var db *nodeDB
@@ -148,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d
 	net := &Network{
 		db:               db,
 		conn:             conn,
+		netrestrict:      netrestrict,
 		tab:              tab,
 		topictab:         newTopicTable(db, tab.self),
 		ticketStore:      newTicketStore(),
@@ -696,6 +699,9 @@ func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n
 	if n == nil {
 		// We haven't seen this node before.
 		n, err = nodeFromRPC(sender, rn)
+		if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
+			return n, errors.New("not contained in netrestrict whitelist")
+		}
 		if err == nil {
 			n.state = unknown
 			net.nodes[n.ID] = n

+ 1 - 1
p2p/discv5/net_test.go

@@ -28,7 +28,7 @@ import (
 
 func TestNetwork_Lookup(t *testing.T) {
 	key, _ := crypto.GenerateKey()
-	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "")
+	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil)
 	if err != nil {
 		t.Fatal(err)
 	}

+ 1 - 1
p2p/discv5/sim_test.go

@@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network {
 	addr := &net.UDPAddr{IP: ip, Port: 30303}
 
 	transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key}
-	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>")
+	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil)
 	if err != nil {
 		panic("cannot launch new node: " + err.Error())
 	}

+ 2 - 2
p2p/discv5/udp.go

@@ -238,12 +238,12 @@ type udp struct {
 }
 
 // ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) {
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
 	transport, err := listenUDP(priv, laddr)
 	if err != nil {
 		return nil, err
 	}
-	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath)
+	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
 	if err != nil {
 		return nil, err
 	}

+ 21 - 4
p2p/server.go

@@ -30,6 +30,7 @@ import (
 	"github.com/ethereum/go-ethereum/p2p/discover"
 	"github.com/ethereum/go-ethereum/p2p/discv5"
 	"github.com/ethereum/go-ethereum/p2p/nat"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
 )
 
 const (
@@ -101,6 +102,11 @@ type Config struct {
 	// allowed to connect, even above the peer limit.
 	TrustedNodes []*discover.Node
 
+	// Connectivity can be restricted to certain IP networks.
+	// If this option is set to a non-nil value, only hosts which match one of the
+	// IP networks contained in the list are considered.
+	NetRestrict *netutil.Netlist
+
 	// NodeDatabase is the path to the database containing the previously seen
 	// live nodes in the network.
 	NodeDatabase string
@@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) {
 
 	// node table
 	if srv.Discovery {
-		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
+		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict)
 		if err != nil {
 			return err
 		}
@@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) {
 	}
 
 	if srv.DiscoveryV5 {
-		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase)
+		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
 		if err != nil {
 			return err
 		}
@@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) {
 	if !srv.Discovery {
 		dynPeers = 0
 	}
-	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers)
+	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict)
 
 	// handshake
 	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)}
@@ -634,8 +640,19 @@ func (srv *Server) listenLoop() {
 			}
 			break
 		}
+
+		// Reject connections that do not match NetRestrict.
+		if srv.NetRestrict != nil {
+			if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
+				glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr())
+				fd.Close()
+				slots <- struct{}{}
+				continue
+			}
+		}
+
 		fd = newMeteredConn(fd, true)
-		glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
+		glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr())
 
 		// Spawn the handler. It will give the slot back when the connection
 		// has been established.