Эх сурвалжийг харах

p2p/discover, p2p/discv5: prevent relay of invalid IPs and low ports

The discovery DHT contains a number of hosts with LAN and loopback IPs.
These get relayed because some implementations do not perform any checks
on the IP.

go-ethereum already prevented relay in most cases because it verifies
that the host actually exists before adding it to the local table. But
this verification causes other issues. We have received several reports
where people's VPSs got shut down by hosting providers because sending
packets to random LAN hosts is indistinguishable from a slow port scan.

The new check prevents sending random packets to LAN by discarding LAN
IPs sent by Internet hosts (and loopback IPs from LAN and Internet
hosts). The new check also blacklists almost all currently registered
special-purpose networks assigned by IANA to avoid inciting random
responses from services in the LAN.

As another precaution against abuse of the DHT, ports below 1024 are now
considered invalid.
Felix Lange 9 жил өмнө
parent
commit
a98d1d67d6

+ 1 - 0
p2p/discover/table_test.go

@@ -146,6 +146,7 @@ func fillBucket(tab *Table, ld int) (last *Node) {
 func nodeAtDistance(base common.Hash, ld int) (n *Node) {
 func nodeAtDistance(base common.Hash, ld int) (n *Node) {
 	n = new(Node)
 	n = new(Node)
 	n.sha = hashAtDistance(base, ld)
 	n.sha = hashAtDistance(base, ld)
+	n.IP = net.IP{10, 0, 2, byte(ld)}
 	copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
 	copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
 	return n
 	return n
 }
 }

+ 16 - 4
p2p/discover/udp.go

@@ -29,6 +29,7 @@ import (
 	"github.com/ethereum/go-ethereum/logger"
 	"github.com/ethereum/go-ethereum/logger"
 	"github.com/ethereum/go-ethereum/logger/glog"
 	"github.com/ethereum/go-ethereum/logger/glog"
 	"github.com/ethereum/go-ethereum/p2p/nat"
 	"github.com/ethereum/go-ethereum/p2p/nat"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 )
 
 
@@ -126,8 +127,13 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
 }
 }
 
 
-func nodeFromRPC(rn rpcNode) (*Node, error) {
-	// TODO: don't accept localhost, LAN addresses from internet hosts
+func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
+	if rn.UDP <= 1024 {
+		return nil, errors.New("low port")
+	}
+	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
+		return nil, err
+	}
 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
 	err := n.validateComplete()
 	err := n.validateComplete()
 	return n, err
 	return n, err
@@ -281,9 +287,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
 		reply := r.(*neighbors)
 		reply := r.(*neighbors)
 		for _, rn := range reply.Nodes {
 		for _, rn := range reply.Nodes {
 			nreceived++
 			nreceived++
-			if n, err := nodeFromRPC(rn); err == nil {
-				nodes = append(nodes, n)
+			n, err := nodeFromRPC(toaddr, rn)
+			if err != nil {
+				glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err)
+				continue
 			}
 			}
+			nodes = append(nodes, n)
 		}
 		}
 		return nreceived >= bucketSize
 		return nreceived >= bucketSize
 	})
 	})
@@ -595,6 +604,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
 	// Send neighbors in chunks with at most maxNeighbors per packet
 	// Send neighbors in chunks with at most maxNeighbors per packet
 	// to stay below the 1280 byte limit.
 	// to stay below the 1280 byte limit.
 	for i, n := range closest {
 	for i, n := range closest {
+		if netutil.CheckRelayIP(from.IP, n.IP) != nil {
+			continue
+		}
 		p.Nodes = append(p.Nodes, nodeToRPC(n))
 		p.Nodes = append(p.Nodes, nodeToRPC(n))
 		if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
 		if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
 			t.send(from, neighborsPacket, p)
 			t.send(from, neighborsPacket, p)

+ 4 - 3
p2p/discover/udp_test.go

@@ -68,7 +68,7 @@ func newUDPTest(t *testing.T) *udpTest {
 		pipe:       newpipe(),
 		pipe:       newpipe(),
 		localkey:   newkey(),
 		localkey:   newkey(),
 		remotekey:  newkey(),
 		remotekey:  newkey(),
-		remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
+		remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
 	}
 	}
 	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "")
 	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "")
 	return test
 	return test
@@ -312,8 +312,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
 	// check that the sent neighbors are all returned by findnode
 	// check that the sent neighbors are all returned by findnode
 	select {
 	select {
 	case result := <-resultc:
 	case result := <-resultc:
-		if !reflect.DeepEqual(result, list) {
-			t.Errorf("neighbors mismatch:\n  got:  %v\n  want: %v", result, list)
+		want := append(list[:2], list[3:]...)
+		if !reflect.DeepEqual(result, want) {
+			t.Errorf("neighbors mismatch:\n  got:  %v\n  want: %v", result, want)
 		}
 		}
 	case err := <-errc:
 	case err := <-errc:
 		t.Errorf("findnode error: %v", err)
 		t.Errorf("findnode error: %v", err)

+ 12 - 7
p2p/discv5/net.go

@@ -45,6 +45,7 @@ const (
 	bucketRefreshInterval = 1 * time.Minute
 	bucketRefreshInterval = 1 * time.Minute
 	seedCount             = 30
 	seedCount             = 30
 	seedMaxAge            = 5 * 24 * time.Hour
 	seedMaxAge            = 5 * 24 * time.Hour
+	lowPort               = 1024
 )
 )
 
 
 const testTopic = "foo"
 const testTopic = "foo"
@@ -684,16 +685,19 @@ func (net *Network) internNodeFromDB(dbn *Node) *Node {
 	return n
 	return n
 }
 }
 
 
-func (net *Network) internNodeFromNeighbours(rn rpcNode) (n *Node, err error) {
+func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
 	if rn.ID == net.tab.self.ID {
 	if rn.ID == net.tab.self.ID {
 		return nil, errors.New("is self")
 		return nil, errors.New("is self")
 	}
 	}
+	if rn.UDP <= lowPort {
+		return nil, errors.New("low port")
+	}
 	n = net.nodes[rn.ID]
 	n = net.nodes[rn.ID]
 	if n == nil {
 	if n == nil {
 		// We haven't seen this node before.
 		// We haven't seen this node before.
-		n, err = nodeFromRPC(rn)
-		n.state = unknown
+		n, err = nodeFromRPC(sender, rn)
 		if err == nil {
 		if err == nil {
+			n.state = unknown
 			net.nodes[n.ID] = n
 			net.nodes[n.ID] = n
 		}
 		}
 		return n, err
 		return n, err
@@ -1095,7 +1099,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket)
 		net.conn.sendNeighbours(n, results)
 		net.conn.sendNeighbours(n, results)
 		return n.state, nil
 		return n.state, nil
 	case neighborsPacket:
 	case neighborsPacket:
-		err := net.handleNeighboursPacket(n, pkt.data.(*neighbors))
+		err := net.handleNeighboursPacket(n, pkt)
 		return n.state, err
 		return n.state, err
 	case neighboursTimeout:
 	case neighboursTimeout:
 		if n.pendingNeighbours != nil {
 		if n.pendingNeighbours != nil {
@@ -1182,17 +1186,18 @@ func rlpHash(x interface{}) (h common.Hash) {
 	return h
 	return h
 }
 }
 
 
-func (net *Network) handleNeighboursPacket(n *Node, req *neighbors) error {
+func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
 	if n.pendingNeighbours == nil {
 	if n.pendingNeighbours == nil {
 		return errNoQuery
 		return errNoQuery
 	}
 	}
 	net.abortTimedEvent(n, neighboursTimeout)
 	net.abortTimedEvent(n, neighboursTimeout)
 
 
+	req := pkt.data.(*neighbors)
 	nodes := make([]*Node, len(req.Nodes))
 	nodes := make([]*Node, len(req.Nodes))
 	for i, rn := range req.Nodes {
 	for i, rn := range req.Nodes {
-		nn, err := net.internNodeFromNeighbours(rn)
+		nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
 		if err != nil {
 		if err != nil {
-			glog.V(logger.Debug).Infof("invalid neighbour from %x: %v", n.ID[:8], err)
+			glog.V(logger.Debug).Infof("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err)
 			continue
 			continue
 		}
 		}
 		nodes[i] = nn
 		nodes[i] = nn

+ 15 - 12
p2p/discv5/net_test.go

@@ -40,7 +40,7 @@ func TestNetwork_Lookup(t *testing.T) {
 	// 	t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
 	// 	t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
 	// }
 	// }
 	// seed table with initial node (otherwise lookup will terminate immediately)
 	// seed table with initial node (otherwise lookup will terminate immediately)
-	seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 999)}
+	seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{10, 0, 2, 99}, lowPort+256, 999)}
 	if err := network.SetFallbackNodes(seeds); err != nil {
 	if err := network.SetFallbackNodes(seeds); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -272,13 +272,13 @@ func (tn *preminedTestnet) sendFindnode(to *Node, target NodeID) {
 func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) {
 func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) {
 	// current log distance is encoded in port number
 	// current log distance is encoded in port number
 	// fmt.Println("findnode query at dist", toaddr.Port)
 	// fmt.Println("findnode query at dist", toaddr.Port)
-	if to.UDP == 0 {
-		panic("query to node at distance 0")
+	if to.UDP <= lowPort {
+		panic("query to node at or below distance 0")
 	}
 	}
 	next := to.UDP - 1
 	next := to.UDP - 1
 	var result []rpcNode
 	var result []rpcNode
-	for i, id := range tn.dists[to.UDP] {
-		result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1)))
+	for i, id := range tn.dists[to.UDP-lowPort] {
+		result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
 	}
 	}
 	injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
 	injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
 }
 }
@@ -296,14 +296,14 @@ func (tn *preminedTestnet) send(to *Node, ptype nodeEvent, data interface{}) (ha
 		// ignored
 		// ignored
 	case findnodeHashPacket:
 	case findnodeHashPacket:
 		// current log distance is encoded in port number
 		// current log distance is encoded in port number
-		// fmt.Println("findnode query at dist", toaddr.Port)
-		if to.UDP == 0 {
-			panic("query to node at distance 0")
+		// fmt.Println("findnode query at dist", toaddr.Port-lowPort)
+		if to.UDP <= lowPort {
+			panic("query to node at or below  distance 0")
 		}
 		}
 		next := to.UDP - 1
 		next := to.UDP - 1
 		var result []rpcNode
 		var result []rpcNode
-		for i, id := range tn.dists[to.UDP] {
-			result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1)))
+		for i, id := range tn.dists[to.UDP-lowPort] {
+			result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
 		}
 		}
 		injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
 		injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
 	default:
 	default:
@@ -328,8 +328,11 @@ func (tn *preminedTestnet) sendTopicRegister(to *Node, topics []Topic, idx int,
 	panic("sendTopicRegister called")
 	panic("sendTopicRegister called")
 }
 }
 
 
-func (*preminedTestnet) Close()                  {}
-func (*preminedTestnet) localAddr() *net.UDPAddr { return new(net.UDPAddr) }
+func (*preminedTestnet) Close() {}
+
+func (*preminedTestnet) localAddr() *net.UDPAddr {
+	return &net.UDPAddr{IP: net.ParseIP("10.0.1.1"), Port: 40000}
+}
 
 
 // mine generates a testnet struct literal with nodes at
 // mine generates a testnet struct literal with nodes at
 // various distances to the given target.
 // various distances to the given target.

+ 8 - 2
p2p/discv5/udp.go

@@ -29,6 +29,7 @@ import (
 	"github.com/ethereum/go-ethereum/logger"
 	"github.com/ethereum/go-ethereum/logger"
 	"github.com/ethereum/go-ethereum/logger/glog"
 	"github.com/ethereum/go-ethereum/logger/glog"
 	"github.com/ethereum/go-ethereum/p2p/nat"
 	"github.com/ethereum/go-ethereum/p2p/nat"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 )
 
 
@@ -198,8 +199,10 @@ func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
 	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP)
 	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP)
 }
 }
 
 
-func nodeFromRPC(rn rpcNode) (*Node, error) {
-	// TODO: don't accept localhost, LAN addresses from internet hosts
+func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
+	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
+		return nil, err
+	}
 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
 	err := n.validateComplete()
 	err := n.validateComplete()
 	return n, err
 	return n, err
@@ -327,6 +330,9 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
 		return
 		return
 	}
 	}
 	for i, result := range nodes {
 	for i, result := range nodes {
+		if netutil.CheckRelayIP(remote.IP, result.IP) != nil {
+			continue
+		}
 		p.Nodes = append(p.Nodes, nodeToRPC(result))
 		p.Nodes = append(p.Nodes, nodeToRPC(result))
 		if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 {
 		if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 {
 			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
 			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)