浏览代码

p2p/nat: limit UPNP request concurrency (#21390)

This adds a lock around requests because some routers can't handle
concurrent requests. Requests are also rate-limited.
 
The Map function request a new mapping exactly when the map timeout
occurs instead of 5 minutes earlier. This should prevent duplicate mappings.
Felix Lange 5 年之前
父节点
当前提交
1d25039ff5
共有 2 个文件被更改,包括 64 次插入25 次删除
  1. 4 5
      p2p/nat/nat.go
  2. 60 20
      p2p/nat/natupnp.go

+ 4 - 5
p2p/nat/nat.go

@@ -91,15 +91,14 @@ func Parse(spec string) (Interface, error) {
 }
 
 const (
-	mapTimeout        = 20 * time.Minute
-	mapUpdateInterval = 15 * time.Minute
+	mapTimeout = 10 * time.Minute
 )
 
 // Map adds a port mapping on m and keeps it alive until c is closed.
 // This function is typically invoked in its own goroutine.
-func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) {
+func Map(m Interface, c <-chan struct{}, protocol string, extport, intport int, name string) {
 	log := log.New("proto", protocol, "extport", extport, "intport", intport, "interface", m)
-	refresh := time.NewTimer(mapUpdateInterval)
+	refresh := time.NewTimer(mapTimeout)
 	defer func() {
 		refresh.Stop()
 		log.Debug("Deleting port mapping")
@@ -121,7 +120,7 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na
 			if err := m.AddMapping(protocol, extport, intport, name, mapTimeout); err != nil {
 				log.Debug("Couldn't add port mapping", "err", err)
 			}
-			refresh.Reset(mapUpdateInterval)
+			refresh.Reset(mapTimeout)
 		}
 	}
 }

+ 60 - 20
p2p/nat/natupnp.go

@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"net"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/huin/goupnp"
@@ -28,12 +29,17 @@ import (
 	"github.com/huin/goupnp/dcps/internetgateway2"
 )
 
-const soapRequestTimeout = 3 * time.Second
+const (
+	soapRequestTimeout = 3 * time.Second
+	rateLimit          = 200 * time.Millisecond
+)
 
 type upnp struct {
-	dev     *goupnp.RootDevice
-	service string
-	client  upnpClient
+	dev         *goupnp.RootDevice
+	service     string
+	client      upnpClient
+	mu          sync.Mutex
+	lastReqTime time.Time
 }
 
 type upnpClient interface {
@@ -43,8 +49,23 @@ type upnpClient interface {
 	GetNATRSIPStatus() (sip bool, nat bool, err error)
 }
 
+func (n *upnp) natEnabled() bool {
+	var ok bool
+	var err error
+	n.withRateLimit(func() error {
+		_, ok, err = n.client.GetNATRSIPStatus()
+		return err
+	})
+	return err == nil && ok
+}
+
 func (n *upnp) ExternalIP() (addr net.IP, err error) {
-	ipString, err := n.client.GetExternalIPAddress()
+	var ipString string
+	n.withRateLimit(func() error {
+		ipString, err = n.client.GetExternalIPAddress()
+		return err
+	})
+
 	if err != nil {
 		return nil, err
 	}
@@ -63,7 +84,10 @@ func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, li
 	protocol = strings.ToUpper(protocol)
 	lifetimeS := uint32(lifetime / time.Second)
 	n.DeleteMapping(protocol, extport, intport)
-	return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
+
+	return n.withRateLimit(func() error {
+		return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
+	})
 }
 
 func (n *upnp) internalAddress() (net.IP, error) {
@@ -90,36 +114,51 @@ func (n *upnp) internalAddress() (net.IP, error) {
 }
 
 func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
-	return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
+	return n.withRateLimit(func() error {
+		return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
+	})
 }
 
 func (n *upnp) String() string {
 	return "UPNP " + n.service
 }
 
+func (n *upnp) withRateLimit(fn func() error) error {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	lastreq := time.Since(n.lastReqTime)
+	if lastreq < rateLimit {
+		time.Sleep(rateLimit - lastreq)
+	}
+	err := fn()
+	n.lastReqTime = time.Now()
+	return err
+}
+
 // discoverUPnP searches for Internet Gateway Devices
 // and returns the first one it can find on the local network.
 func discoverUPnP() Interface {
 	found := make(chan *upnp, 2)
 	// IGDv1
-	go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
+	go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(sc goupnp.ServiceClient) *upnp {
 		switch sc.Service.ServiceType {
 		case internetgateway1.URN_WANIPConnection_1:
-			return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{ServiceClient: sc}}
+			return &upnp{service: "IGDv1-IP1", client: &internetgateway1.WANIPConnection1{ServiceClient: sc}}
 		case internetgateway1.URN_WANPPPConnection_1:
-			return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{ServiceClient: sc}}
+			return &upnp{service: "IGDv1-PPP1", client: &internetgateway1.WANPPPConnection1{ServiceClient: sc}}
 		}
 		return nil
 	})
 	// IGDv2
-	go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
+	go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(sc goupnp.ServiceClient) *upnp {
 		switch sc.Service.ServiceType {
 		case internetgateway2.URN_WANIPConnection_1:
-			return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{ServiceClient: sc}}
+			return &upnp{service: "IGDv2-IP1", client: &internetgateway2.WANIPConnection1{ServiceClient: sc}}
 		case internetgateway2.URN_WANIPConnection_2:
-			return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{ServiceClient: sc}}
+			return &upnp{service: "IGDv2-IP2", client: &internetgateway2.WANIPConnection2{ServiceClient: sc}}
 		case internetgateway2.URN_WANPPPConnection_1:
-			return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{ServiceClient: sc}}
+			return &upnp{service: "IGDv2-PPP1", client: &internetgateway2.WANPPPConnection1{ServiceClient: sc}}
 		}
 		return nil
 	})
@@ -134,7 +173,7 @@ func discoverUPnP() Interface {
 // finds devices matching the given target and calls matcher for all
 // advertised services of each device. The first non-nil service found
 // is sent into out. If no service matched, nil is sent.
-func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) {
+func discover(out chan<- *upnp, target string, matcher func(goupnp.ServiceClient) *upnp) {
 	devs, err := goupnp.DiscoverDevices(target)
 	if err != nil {
 		out <- nil
@@ -157,16 +196,17 @@ func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice,
 				Service:    service,
 			}
 			sc.SOAPClient.HTTPClient.Timeout = soapRequestTimeout
-			upnp := matcher(devs[i].Root, sc)
+			upnp := matcher(sc)
 			if upnp == nil {
 				return
 			}
+			upnp.dev = devs[i].Root
+
 			// check whether port mapping is enabled
-			if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat {
-				return
+			if upnp.natEnabled() {
+				out <- upnp
+				found = true
 			}
-			out <- upnp
-			found = true
 		})
 	}
 	if !found {