udp.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. // Copyright 2016 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package discv5
  17. import (
  18. "bytes"
  19. "crypto/ecdsa"
  20. "errors"
  21. "fmt"
  22. "net"
  23. "time"
  24. "github.com/ethereum/go-ethereum/common"
  25. "github.com/ethereum/go-ethereum/crypto"
  26. "github.com/ethereum/go-ethereum/logger"
  27. "github.com/ethereum/go-ethereum/logger/glog"
  28. "github.com/ethereum/go-ethereum/p2p/nat"
  29. "github.com/ethereum/go-ethereum/p2p/netutil"
  30. "github.com/ethereum/go-ethereum/rlp"
  31. )
  32. const Version = 4
  33. // Errors
  34. var (
  35. errPacketTooSmall = errors.New("too small")
  36. errBadHash = errors.New("bad hash")
  37. errExpired = errors.New("expired")
  38. errUnsolicitedReply = errors.New("unsolicited reply")
  39. errUnknownNode = errors.New("unknown node")
  40. errTimeout = errors.New("RPC timeout")
  41. errClockWarp = errors.New("reply deadline too far in the future")
  42. errClosed = errors.New("socket closed")
  43. )
  44. // Timeouts
  45. const (
  46. respTimeout = 500 * time.Millisecond
  47. sendTimeout = 500 * time.Millisecond
  48. expiration = 20 * time.Second
  49. ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP
  50. ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
  51. driftThreshold = 10 * time.Second // Allowed clock drift before warning user
  52. )
// RPC request structures.
//
// Each of these is RLP-encoded as the payload of a packet (see encodePacket
// and decodePacket), so the order of fields is part of the wire format and
// must not be changed.
type (
	// ping is the request that pong replies to. Version carries the
	// protocol Version; From/To are the sender's and recipient's endpoints.
	ping struct {
		Version    uint
		From, To   rpcEndpoint
		Expiration uint64 // Absolute Unix time (seconds) at which the packet becomes invalid.
		// v5
		Topics []Topic
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// pong is the reply to ping.
	pong struct {
		// This field should mirror the UDP envelope address
		// of the ping packet, which provides a way to discover
		// the external address (after NAT).
		To         rpcEndpoint
		ReplyTok   []byte // This contains the hash of the ping packet.
		Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
		// v5
		TopicHash    common.Hash
		TicketSerial uint32
		WaitPeriods  []uint32
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnode is a query for nodes close to the given target.
	findnode struct {
		Target     NodeID // doesn't need to be an actual public key
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnodeHash is a query for nodes close to the given target hash
	// (the common.Hash counterpart of findnode).
	findnodeHash struct {
		Target     common.Hash
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// reply to findnode; a long result list is split across several
	// neighbors packets (see sendNeighbours and maxNeighbors).
	neighbors struct {
		Nodes      []rpcNode
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// topicRegister is the packet built by sendTopicRegister.
	// NOTE(review): Idx presumably selects one entry of Topics and Pong
	// presumably carries a pong previously received from the recipient —
	// confirm against the handler.
	topicRegister struct {
		Topics []Topic
		Idx    uint
		Pong   []byte
	}

	// topicQuery is a query for a single topic; topicNodes is the reply.
	topicQuery struct {
		Topic      Topic
		Expiration uint64
	}

	// reply to topicQuery. Echo carries the hash of the query being
	// answered (sendTopicNodes fills it from its queryHash argument).
	topicNodes struct {
		Echo  common.Hash
		Nodes []rpcNode
	}

	// rpcNode is the wire representation of a node (see nodeToRPC/nodeFromRPC).
	rpcNode struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
		ID  NodeID
	}

	// rpcEndpoint is the wire representation of a network endpoint.
	rpcEndpoint struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
	}
)
  126. const (
  127. macSize = 256 / 8
  128. sigSize = 520 / 8
  129. headSize = macSize + sigSize // space of packet frame data
  130. )
  131. // Neighbors replies are sent across multiple packets to
  132. // stay below the 1280 byte limit. We compute the maximum number
  133. // of entries by stuffing a packet until it grows too large.
  134. var maxNeighbors = func() int {
  135. p := neighbors{Expiration: ^uint64(0)}
  136. maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
  137. for n := 0; ; n++ {
  138. p.Nodes = append(p.Nodes, maxSizeNode)
  139. size, _, err := rlp.EncodeToReader(p)
  140. if err != nil {
  141. // If this ever happens, it will be caught by the unit tests.
  142. panic("cannot encode: " + err.Error())
  143. }
  144. if headSize+size+1 >= 1280 {
  145. return n
  146. }
  147. }
  148. }()
  149. var maxTopicNodes = func() int {
  150. p := topicNodes{}
  151. maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
  152. for n := 0; ; n++ {
  153. p.Nodes = append(p.Nodes, maxSizeNode)
  154. size, _, err := rlp.EncodeToReader(p)
  155. if err != nil {
  156. // If this ever happens, it will be caught by the unit tests.
  157. panic("cannot encode: " + err.Error())
  158. }
  159. if headSize+size+1 >= 1280 {
  160. return n
  161. }
  162. }
  163. }()
  164. func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
  165. ip := addr.IP.To4()
  166. if ip == nil {
  167. ip = addr.IP.To16()
  168. }
  169. return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
  170. }
  171. func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
  172. return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP)
  173. }
  174. func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
  175. if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
  176. return nil, err
  177. }
  178. n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
  179. err := n.validateComplete()
  180. return n, err
  181. }
  182. func nodeToRPC(n *Node) rpcNode {
  183. return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
  184. }
// ingressPacket is a received packet after decodePacket has verified and
// decoded it; handlePacket hands it to the network via reqReadPacket.
type ingressPacket struct {
	remoteID   NodeID       // sender ID recovered from the packet signature
	remoteAddr *net.UDPAddr // UDP source address of the packet
	ev         nodeEvent    // packet type, taken from the byte after the header
	hash       []byte       // the packet's leading macSize-byte hash
	data       interface{} // one of the RPC structs
	rawData    []byte       // private copy of the complete packet bytes
}
  193. type conn interface {
  194. ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
  195. WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
  196. Close() error
  197. LocalAddr() net.Addr
  198. }
// udp implements the RPC protocol.
type udp struct {
	conn        conn              // socket used by readLoop and sendPacket
	priv        *ecdsa.PrivateKey // key that encodePacket signs outgoing packets with
	ourEndpoint rpcEndpoint       // advertised as From in outgoing pings
	// NOTE(review): nat is not set or read in this file; presumably used
	// for NAT port mapping elsewhere — confirm.
	nat nat.Interface
	net *Network // receives decoded packets via reqReadPacket
}
  207. // ListenUDP returns a new table that listens for UDP packets on laddr.
  208. func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
  209. transport, err := listenUDP(priv, laddr)
  210. if err != nil {
  211. return nil, err
  212. }
  213. net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
  214. if err != nil {
  215. return nil, err
  216. }
  217. transport.net = net
  218. go transport.readLoop()
  219. return net, nil
  220. }
  221. func listenUDP(priv *ecdsa.PrivateKey, laddr string) (*udp, error) {
  222. addr, err := net.ResolveUDPAddr("udp", laddr)
  223. if err != nil {
  224. return nil, err
  225. }
  226. conn, err := net.ListenUDP("udp", addr)
  227. if err != nil {
  228. return nil, err
  229. }
  230. return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(addr, uint16(addr.Port))}, nil
  231. }
  232. func (t *udp) localAddr() *net.UDPAddr {
  233. return t.conn.LocalAddr().(*net.UDPAddr)
  234. }
  235. func (t *udp) Close() {
  236. t.conn.Close()
  237. }
  238. func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
  239. hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
  240. return hash
  241. }
  242. func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
  243. hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
  244. Version: Version,
  245. From: t.ourEndpoint,
  246. To: makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
  247. Expiration: uint64(time.Now().Add(expiration).Unix()),
  248. Topics: topics,
  249. })
  250. return hash
  251. }
  252. func (t *udp) sendFindnode(remote *Node, target NodeID) {
  253. t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
  254. Target: target,
  255. Expiration: uint64(time.Now().Add(expiration).Unix()),
  256. })
  257. }
  258. func (t *udp) sendNeighbours(remote *Node, results []*Node) {
  259. // Send neighbors in chunks with at most maxNeighbors per packet
  260. // to stay below the 1280 byte limit.
  261. p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
  262. for i, result := range results {
  263. p.Nodes = append(p.Nodes, nodeToRPC(result))
  264. if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
  265. t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
  266. p.Nodes = p.Nodes[:0]
  267. }
  268. }
  269. }
  270. func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
  271. t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
  272. Target: target,
  273. Expiration: uint64(time.Now().Add(expiration).Unix()),
  274. })
  275. }
  276. func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
  277. t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
  278. Topics: topics,
  279. Idx: uint(idx),
  280. Pong: pong,
  281. })
  282. }
  283. func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
  284. p := topicNodes{Echo: queryHash}
  285. if len(nodes) == 0 {
  286. t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
  287. return
  288. }
  289. for i, result := range nodes {
  290. if netutil.CheckRelayIP(remote.IP, result.IP) != nil {
  291. continue
  292. }
  293. p.Nodes = append(p.Nodes, nodeToRPC(result))
  294. if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 {
  295. t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
  296. p.Nodes = p.Nodes[:0]
  297. }
  298. }
  299. }
  300. func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
  301. //fmt.Println("sendPacket", nodeEvent(ptype), toaddr.String(), toid.String())
  302. packet, hash, err := encodePacket(t.priv, ptype, req)
  303. if err != nil {
  304. //fmt.Println(err)
  305. return hash, err
  306. }
  307. glog.V(logger.Detail).Infof(">>> %v to %x@%v\n", nodeEvent(ptype), toid[:8], toaddr)
  308. if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
  309. glog.V(logger.Detail).Infoln("UDP send failed:", err)
  310. }
  311. //fmt.Println(err)
  312. return hash, err
  313. }
  314. // zeroed padding space for encodePacket.
  315. var headSpace = make([]byte, headSize)
  316. func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
  317. b := new(bytes.Buffer)
  318. b.Write(headSpace)
  319. b.WriteByte(ptype)
  320. if err := rlp.Encode(b, req); err != nil {
  321. glog.V(logger.Error).Infoln("error encoding packet:", err)
  322. return nil, nil, err
  323. }
  324. packet := b.Bytes()
  325. sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
  326. if err != nil {
  327. glog.V(logger.Error).Infoln("could not sign packet:", err)
  328. return nil, nil, err
  329. }
  330. copy(packet[macSize:], sig)
  331. // add the hash to the front. Note: this doesn't protect the
  332. // packet in any way.
  333. hash = crypto.Keccak256(packet[macSize:])
  334. copy(packet, hash)
  335. return packet, hash, nil
  336. }
  337. // readLoop runs in its own goroutine. it injects ingress UDP packets
  338. // into the network loop.
  339. func (t *udp) readLoop() {
  340. defer t.conn.Close()
  341. // Discovery packets are defined to be no larger than 1280 bytes.
  342. // Packets larger than this size will be cut at the end and treated
  343. // as invalid because their hash won't match.
  344. buf := make([]byte, 1280)
  345. for {
  346. nbytes, from, err := t.conn.ReadFromUDP(buf)
  347. if netutil.IsTemporaryError(err) {
  348. // Ignore temporary read errors.
  349. glog.V(logger.Debug).Infof("Temporary read error: %v", err)
  350. continue
  351. } else if err != nil {
  352. // Shut down the loop for permament errors.
  353. glog.V(logger.Debug).Infof("Read error: %v", err)
  354. return
  355. }
  356. t.handlePacket(from, buf[:nbytes])
  357. }
  358. }
  359. func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
  360. pkt := ingressPacket{remoteAddr: from}
  361. if err := decodePacket(buf, &pkt); err != nil {
  362. glog.V(logger.Debug).Infof("Bad packet from %v: %v\n", from, err)
  363. //fmt.Println("bad packet", err)
  364. return err
  365. }
  366. t.net.reqReadPacket(pkt)
  367. return nil
  368. }
  369. func decodePacket(buffer []byte, pkt *ingressPacket) error {
  370. if len(buffer) < headSize+1 {
  371. return errPacketTooSmall
  372. }
  373. buf := make([]byte, len(buffer))
  374. copy(buf, buffer)
  375. hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
  376. shouldhash := crypto.Keccak256(buf[macSize:])
  377. if !bytes.Equal(hash, shouldhash) {
  378. return errBadHash
  379. }
  380. fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
  381. if err != nil {
  382. return err
  383. }
  384. pkt.rawData = buf
  385. pkt.hash = hash
  386. pkt.remoteID = fromID
  387. switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
  388. case pingPacket:
  389. pkt.data = new(ping)
  390. case pongPacket:
  391. pkt.data = new(pong)
  392. case findnodePacket:
  393. pkt.data = new(findnode)
  394. case neighborsPacket:
  395. pkt.data = new(neighbors)
  396. case findnodeHashPacket:
  397. pkt.data = new(findnodeHash)
  398. case topicRegisterPacket:
  399. pkt.data = new(topicRegister)
  400. case topicQueryPacket:
  401. pkt.data = new(topicQuery)
  402. case topicNodesPacket:
  403. pkt.data = new(topicNodes)
  404. default:
  405. return fmt.Errorf("unknown packet type: %d", sigdata[0])
  406. }
  407. s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0)
  408. err = s.Decode(pkt.data)
  409. return err
  410. }