udp.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. // Copyright 2015 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package discover
  17. import (
  18. "bytes"
  19. "container/list"
  20. "crypto/ecdsa"
  21. "errors"
  22. "fmt"
  23. "net"
  24. "time"
  25. "github.com/ethereum/go-ethereum/crypto"
  26. "github.com/ethereum/go-ethereum/logger"
  27. "github.com/ethereum/go-ethereum/logger/glog"
  28. "github.com/ethereum/go-ethereum/p2p/nat"
  29. "github.com/ethereum/go-ethereum/rlp"
  30. )
  31. const Version = 4
  32. // Errors
  33. var (
  34. errPacketTooSmall = errors.New("too small")
  35. errBadHash = errors.New("bad hash")
  36. errExpired = errors.New("expired")
  37. errUnsolicitedReply = errors.New("unsolicited reply")
  38. errUnknownNode = errors.New("unknown node")
  39. errTimeout = errors.New("RPC timeout")
  40. errClockWarp = errors.New("reply deadline too far in the future")
  41. errClosed = errors.New("socket closed")
  42. )
  43. // Timeouts
  44. const (
  45. respTimeout = 500 * time.Millisecond
  46. sendTimeout = 500 * time.Millisecond
  47. expiration = 20 * time.Second
  48. ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP
  49. ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
  50. driftThreshold = 10 * time.Second // Allowed clock drift before warning user
  51. )
  52. // RPC packet types
  53. const (
  54. pingPacket = iota + 1 // zero is 'reserved'
  55. pongPacket
  56. findnodePacket
  57. neighborsPacket
  58. )
  59. // RPC request structures
  60. type (
  61. ping struct {
  62. Version uint
  63. From, To rpcEndpoint
  64. Expiration uint64
  65. // Ignore additional fields (for forward compatibility).
  66. Rest []rlp.RawValue `rlp:"tail"`
  67. }
  68. // pong is the reply to ping.
  69. pong struct {
  70. // This field should mirror the UDP envelope address
  71. // of the ping packet, which provides a way to discover the
  72. // the external address (after NAT).
  73. To rpcEndpoint
  74. ReplyTok []byte // This contains the hash of the ping packet.
  75. Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
  76. // Ignore additional fields (for forward compatibility).
  77. Rest []rlp.RawValue `rlp:"tail"`
  78. }
  79. // findnode is a query for nodes close to the given target.
  80. findnode struct {
  81. Target NodeID // doesn't need to be an actual public key
  82. Expiration uint64
  83. // Ignore additional fields (for forward compatibility).
  84. Rest []rlp.RawValue `rlp:"tail"`
  85. }
  86. // reply to findnode
  87. neighbors struct {
  88. Nodes []rpcNode
  89. Expiration uint64
  90. // Ignore additional fields (for forward compatibility).
  91. Rest []rlp.RawValue `rlp:"tail"`
  92. }
  93. rpcNode struct {
  94. IP net.IP // len 4 for IPv4 or 16 for IPv6
  95. UDP uint16 // for discovery protocol
  96. TCP uint16 // for RLPx protocol
  97. ID NodeID
  98. }
  99. rpcEndpoint struct {
  100. IP net.IP // len 4 for IPv4 or 16 for IPv6
  101. UDP uint16 // for discovery protocol
  102. TCP uint16 // for RLPx protocol
  103. }
  104. )
  105. func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
  106. ip := addr.IP.To4()
  107. if ip == nil {
  108. ip = addr.IP.To16()
  109. }
  110. return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
  111. }
  112. func nodeFromRPC(rn rpcNode) (*Node, error) {
  113. // TODO: don't accept localhost, LAN addresses from internet hosts
  114. n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
  115. err := n.validateComplete()
  116. return n, err
  117. }
  118. func nodeToRPC(n *Node) rpcNode {
  119. return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
  120. }
  121. type packet interface {
  122. handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
  123. }
  124. type conn interface {
  125. ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
  126. WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
  127. Close() error
  128. LocalAddr() net.Addr
  129. }
  130. // udp implements the RPC protocol.
  131. type udp struct {
  132. conn conn
  133. priv *ecdsa.PrivateKey
  134. ourEndpoint rpcEndpoint
  135. addpending chan *pending
  136. gotreply chan reply
  137. closing chan struct{}
  138. nat nat.Interface
  139. *Table
  140. }
  141. // pending represents a pending reply.
  142. //
  143. // some implementations of the protocol wish to send more than one
  144. // reply packet to findnode. in general, any neighbors packet cannot
  145. // be matched up with a specific findnode packet.
  146. //
  147. // our implementation handles this by storing a callback function for
  148. // each pending reply. incoming packets from a node are dispatched
  149. // to all the callback functions for that node.
  150. type pending struct {
  151. // these fields must match in the reply.
  152. from NodeID
  153. ptype byte
  154. // time when the request must complete
  155. deadline time.Time
  156. // callback is called when a matching reply arrives. if it returns
  157. // true, the callback is removed from the pending reply queue.
  158. // if it returns false, the reply is considered incomplete and
  159. // the callback will be invoked again for the next matching reply.
  160. callback func(resp interface{}) (done bool)
  161. // errc receives nil when the callback indicates completion or an
  162. // error if no further reply is received within the timeout.
  163. errc chan<- error
  164. }
  165. type reply struct {
  166. from NodeID
  167. ptype byte
  168. data interface{}
  169. // loop indicates whether there was
  170. // a matching request by sending on this channel.
  171. matched chan<- bool
  172. }
  173. // ListenUDP returns a new table that listens for UDP packets on laddr.
  174. func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) {
  175. addr, err := net.ResolveUDPAddr("udp", laddr)
  176. if err != nil {
  177. return nil, err
  178. }
  179. conn, err := net.ListenUDP("udp", addr)
  180. if err != nil {
  181. return nil, err
  182. }
  183. tab, _, err := newUDP(priv, conn, natm, nodeDBPath)
  184. if err != nil {
  185. return nil, err
  186. }
  187. glog.V(logger.Info).Infoln("Listening,", tab.self)
  188. return tab, nil
  189. }
  190. func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) {
  191. udp := &udp{
  192. conn: c,
  193. priv: priv,
  194. closing: make(chan struct{}),
  195. gotreply: make(chan reply),
  196. addpending: make(chan *pending),
  197. }
  198. realaddr := c.LocalAddr().(*net.UDPAddr)
  199. if natm != nil {
  200. if !realaddr.IP.IsLoopback() {
  201. go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
  202. }
  203. // TODO: react to external IP changes over time.
  204. if ext, err := natm.ExternalIP(); err == nil {
  205. realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
  206. }
  207. }
  208. // TODO: separate TCP port
  209. udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
  210. tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
  211. if err != nil {
  212. return nil, nil, err
  213. }
  214. udp.Table = tab
  215. go udp.loop()
  216. go udp.readLoop()
  217. return udp.Table, udp, nil
  218. }
  219. func (t *udp) close() {
  220. close(t.closing)
  221. t.conn.Close()
  222. // TODO: wait for the loops to end.
  223. }
  224. // ping sends a ping message to the given node and waits for a reply.
  225. func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
  226. // TODO: maybe check for ReplyTo field in callback to measure RTT
  227. errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
  228. t.send(toaddr, pingPacket, ping{
  229. Version: Version,
  230. From: t.ourEndpoint,
  231. To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
  232. Expiration: uint64(time.Now().Add(expiration).Unix()),
  233. })
  234. return <-errc
  235. }
  236. func (t *udp) waitping(from NodeID) error {
  237. return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
  238. }
  239. // findnode sends a findnode request to the given node and waits until
  240. // the node has sent up to k neighbors.
  241. func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
  242. nodes := make([]*Node, 0, bucketSize)
  243. nreceived := 0
  244. errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
  245. reply := r.(*neighbors)
  246. for _, rn := range reply.Nodes {
  247. nreceived++
  248. if n, err := nodeFromRPC(rn); err == nil {
  249. nodes = append(nodes, n)
  250. }
  251. }
  252. return nreceived >= bucketSize
  253. })
  254. t.send(toaddr, findnodePacket, findnode{
  255. Target: target,
  256. Expiration: uint64(time.Now().Add(expiration).Unix()),
  257. })
  258. err := <-errc
  259. return nodes, err
  260. }
  261. // pending adds a reply callback to the pending reply queue.
  262. // see the documentation of type pending for a detailed explanation.
  263. func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
  264. ch := make(chan error, 1)
  265. p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
  266. select {
  267. case t.addpending <- p:
  268. // loop will handle it
  269. case <-t.closing:
  270. ch <- errClosed
  271. }
  272. return ch
  273. }
  274. func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
  275. matched := make(chan bool, 1)
  276. select {
  277. case t.gotreply <- reply{from, ptype, req, matched}:
  278. // loop will handle it
  279. return <-matched
  280. case <-t.closing:
  281. return false
  282. }
  283. }
  284. // loop runs in its own goroutine. it keeps track of
  285. // the refresh timer and the pending reply queue.
  286. func (t *udp) loop() {
  287. var (
  288. plist = list.New()
  289. timeout = time.NewTimer(0)
  290. nextTimeout *pending // head of plist when timeout was last reset
  291. contTimeouts = 0 // number of continuous timeouts to do NTP checks
  292. ntpWarnTime = time.Unix(0, 0)
  293. )
  294. <-timeout.C // ignore first timeout
  295. defer timeout.Stop()
  296. resetTimeout := func() {
  297. if plist.Front() == nil || nextTimeout == plist.Front().Value {
  298. return
  299. }
  300. // Start the timer so it fires when the next pending reply has expired.
  301. now := time.Now()
  302. for el := plist.Front(); el != nil; el = el.Next() {
  303. nextTimeout = el.Value.(*pending)
  304. if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
  305. timeout.Reset(dist)
  306. return
  307. }
  308. // Remove pending replies whose deadline is too far in the
  309. // future. These can occur if the system clock jumped
  310. // backwards after the deadline was assigned.
  311. nextTimeout.errc <- errClockWarp
  312. plist.Remove(el)
  313. }
  314. nextTimeout = nil
  315. timeout.Stop()
  316. }
  317. for {
  318. resetTimeout()
  319. select {
  320. case <-t.closing:
  321. for el := plist.Front(); el != nil; el = el.Next() {
  322. el.Value.(*pending).errc <- errClosed
  323. }
  324. return
  325. case p := <-t.addpending:
  326. p.deadline = time.Now().Add(respTimeout)
  327. plist.PushBack(p)
  328. case r := <-t.gotreply:
  329. var matched bool
  330. for el := plist.Front(); el != nil; el = el.Next() {
  331. p := el.Value.(*pending)
  332. if p.from == r.from && p.ptype == r.ptype {
  333. matched = true
  334. // Remove the matcher if its callback indicates
  335. // that all replies have been received. This is
  336. // required for packet types that expect multiple
  337. // reply packets.
  338. if p.callback(r.data) {
  339. p.errc <- nil
  340. plist.Remove(el)
  341. }
  342. // Reset the continuous timeout counter (time drift detection)
  343. contTimeouts = 0
  344. }
  345. }
  346. r.matched <- matched
  347. case now := <-timeout.C:
  348. nextTimeout = nil
  349. // Notify and remove callbacks whose deadline is in the past.
  350. for el := plist.Front(); el != nil; el = el.Next() {
  351. p := el.Value.(*pending)
  352. if now.After(p.deadline) || now.Equal(p.deadline) {
  353. p.errc <- errTimeout
  354. plist.Remove(el)
  355. contTimeouts++
  356. }
  357. }
  358. // If we've accumulated too many timeouts, do an NTP time sync check
  359. if contTimeouts > ntpFailureThreshold {
  360. if time.Since(ntpWarnTime) >= ntpWarningCooldown {
  361. ntpWarnTime = time.Now()
  362. go checkClockDrift()
  363. }
  364. contTimeouts = 0
  365. }
  366. }
  367. }
  368. }
  369. const (
  370. macSize = 256 / 8
  371. sigSize = 520 / 8
  372. headSize = macSize + sigSize // space of packet frame data
  373. )
  374. var (
  375. headSpace = make([]byte, headSize)
  376. // Neighbors replies are sent across multiple packets to
  377. // stay below the 1280 byte limit. We compute the maximum number
  378. // of entries by stuffing a packet until it grows too large.
  379. maxNeighbors int
  380. )
  381. func init() {
  382. p := neighbors{Expiration: ^uint64(0)}
  383. maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
  384. for n := 0; ; n++ {
  385. p.Nodes = append(p.Nodes, maxSizeNode)
  386. size, _, err := rlp.EncodeToReader(p)
  387. if err != nil {
  388. // If this ever happens, it will be caught by the unit tests.
  389. panic("cannot encode: " + err.Error())
  390. }
  391. if headSize+size+1 >= 1280 {
  392. maxNeighbors = n
  393. break
  394. }
  395. }
  396. }
  397. func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
  398. packet, err := encodePacket(t.priv, ptype, req)
  399. if err != nil {
  400. return err
  401. }
  402. glog.V(logger.Detail).Infof(">>> %v %T\n", toaddr, req)
  403. if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
  404. glog.V(logger.Detail).Infoln("UDP send failed:", err)
  405. }
  406. return err
  407. }
  408. func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
  409. b := new(bytes.Buffer)
  410. b.Write(headSpace)
  411. b.WriteByte(ptype)
  412. if err := rlp.Encode(b, req); err != nil {
  413. glog.V(logger.Error).Infoln("error encoding packet:", err)
  414. return nil, err
  415. }
  416. packet := b.Bytes()
  417. sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
  418. if err != nil {
  419. glog.V(logger.Error).Infoln("could not sign packet:", err)
  420. return nil, err
  421. }
  422. copy(packet[macSize:], sig)
  423. // add the hash to the front. Note: this doesn't protect the
  424. // packet in any way. Our public key will be part of this hash in
  425. // The future.
  426. copy(packet, crypto.Keccak256(packet[macSize:]))
  427. return packet, nil
  428. }
  429. func isTemporaryError(err error) bool {
  430. tempErr, ok := err.(interface {
  431. Temporary() bool
  432. })
  433. return ok && tempErr.Temporary() || isPacketTooBig(err)
  434. }
  435. // readLoop runs in its own goroutine. it handles incoming UDP packets.
  436. func (t *udp) readLoop() {
  437. defer t.conn.Close()
  438. // Discovery packets are defined to be no larger than 1280 bytes.
  439. // Packets larger than this size will be cut at the end and treated
  440. // as invalid because their hash won't match.
  441. buf := make([]byte, 1280)
  442. for {
  443. nbytes, from, err := t.conn.ReadFromUDP(buf)
  444. if isTemporaryError(err) {
  445. // Ignore temporary read errors.
  446. glog.V(logger.Debug).Infof("Temporary read error: %v", err)
  447. continue
  448. } else if err != nil {
  449. // Shut down the loop for permament errors.
  450. glog.V(logger.Debug).Infof("Read error: %v", err)
  451. return
  452. }
  453. t.handlePacket(from, buf[:nbytes])
  454. }
  455. }
  456. func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
  457. packet, fromID, hash, err := decodePacket(buf)
  458. if err != nil {
  459. glog.V(logger.Debug).Infof("Bad packet from %v: %v\n", from, err)
  460. return err
  461. }
  462. status := "ok"
  463. if err = packet.handle(t, from, fromID, hash); err != nil {
  464. status = err.Error()
  465. }
  466. glog.V(logger.Detail).Infof("<<< %v %T: %s\n", from, packet, status)
  467. return err
  468. }
  469. func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
  470. if len(buf) < headSize+1 {
  471. return nil, NodeID{}, nil, errPacketTooSmall
  472. }
  473. hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
  474. shouldhash := crypto.Keccak256(buf[macSize:])
  475. if !bytes.Equal(hash, shouldhash) {
  476. return nil, NodeID{}, nil, errBadHash
  477. }
  478. fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
  479. if err != nil {
  480. return nil, NodeID{}, hash, err
  481. }
  482. var req packet
  483. switch ptype := sigdata[0]; ptype {
  484. case pingPacket:
  485. req = new(ping)
  486. case pongPacket:
  487. req = new(pong)
  488. case findnodePacket:
  489. req = new(findnode)
  490. case neighborsPacket:
  491. req = new(neighbors)
  492. default:
  493. return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
  494. }
  495. s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0)
  496. err = s.Decode(req)
  497. return req, fromID, hash, err
  498. }
  499. func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
  500. if expired(req.Expiration) {
  501. return errExpired
  502. }
  503. t.send(from, pongPacket, pong{
  504. To: makeEndpoint(from, req.From.TCP),
  505. ReplyTok: mac,
  506. Expiration: uint64(time.Now().Add(expiration).Unix()),
  507. })
  508. if !t.handleReply(fromID, pingPacket, req) {
  509. // Note: we're ignoring the provided IP address right now
  510. go t.bond(true, fromID, from, req.From.TCP)
  511. }
  512. return nil
  513. }
  514. func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
  515. if expired(req.Expiration) {
  516. return errExpired
  517. }
  518. if !t.handleReply(fromID, pongPacket, req) {
  519. return errUnsolicitedReply
  520. }
  521. return nil
  522. }
  523. func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
  524. if expired(req.Expiration) {
  525. return errExpired
  526. }
  527. if t.db.node(fromID) == nil {
  528. // No bond exists, we don't process the packet. This prevents
  529. // an attack vector where the discovery protocol could be used
  530. // to amplify traffic in a DDOS attack. A malicious actor
  531. // would send a findnode request with the IP address and UDP
  532. // port of the target as the source address. The recipient of
  533. // the findnode packet would then send a neighbors packet
  534. // (which is a much bigger packet than findnode) to the victim.
  535. return errUnknownNode
  536. }
  537. target := crypto.Keccak256Hash(req.Target[:])
  538. t.mutex.Lock()
  539. closest := t.closest(target, bucketSize).entries
  540. t.mutex.Unlock()
  541. p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
  542. // Send neighbors in chunks with at most maxNeighbors per packet
  543. // to stay below the 1280 byte limit.
  544. for i, n := range closest {
  545. p.Nodes = append(p.Nodes, nodeToRPC(n))
  546. if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
  547. t.send(from, neighborsPacket, p)
  548. p.Nodes = p.Nodes[:0]
  549. }
  550. }
  551. return nil
  552. }
  553. func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
  554. if expired(req.Expiration) {
  555. return errExpired
  556. }
  557. if !t.handleReply(fromID, neighborsPacket, req) {
  558. return errUnsolicitedReply
  559. }
  560. return nil
  561. }
  562. func expired(ts uint64) bool {
  563. return time.Unix(int64(ts), 0).Before(time.Now())
  564. }