server.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. package p2p
  2. import (
  3. "blockchain-go/common/gopool"
  4. "blockchain-go/p2p/discover"
  5. "blockchain-go/p2p/enode"
  6. "blockchain-go/p2p/enr"
  7. "blockchain-go/p2p/nat"
  8. "blockchain-go/p2p/rlpx"
  9. "blockchain-go/params"
  10. "crypto/ecdsa"
  11. "errors"
  12. "fmt"
  13. "github.com/ethereum/go-ethereum/crypto"
  14. "github.com/ethereum/go-ethereum/p2p/netutil"
  15. "net"
  16. "sync"
  17. "time"
  18. )
  19. var (
  20. ErrServerStopped = errors.New("server stopped")
  21. )
// Server manages all peer connections.
//
// Config is embedded, so configuration values such as ListenAddr, NAT,
// PrivateKey and Protocols are read directly from the server. The remaining
// fields are runtime state populated by Start and its setup helpers.
// Commented-out fields are kept as placeholders for functionality that is not
// wired up in this file.
type Server struct {
	Config
	//newTransport func(net.Conn, *ecdsa.PublicKey) transport
	//newPeerHook func(*Peer)
	//listenFunc func(network, addr string) (net.Listener, error)
	lock sync.Mutex // protects running
	// NOTE(review): running is not read or written anywhere in the visible
	// part of this file — confirm where it is maintained.
	running bool
	// listener is the TCP listener opened by setupListening and served by
	// listenLoop.
	listener net.Listener
	// ourHandshake is the protocol handshake built in setupLocalNode and
	// passed to doProtoHandshake for every connection.
	ourHandshake *protoHandshake
	// loopWG counts listenLoop and the NAT port-mapping tasks started by the
	// setup helpers.
	loopWG sync.WaitGroup // loop, listenLoop
	//peerFeed event.Feed
	//log log.Logger
	//nodedb *enode.DB
	// localnode is created in setupLocalNode from the server's private key.
	localnode *enode.LocalNode
	// ntab is the discovery v4 instance returned by discover.ListenV4.
	ntab *discover.UDPv4
	//DiscV5 *discover.UDPv5
	// discmix combines the per-protocol dial candidates with ntab's random
	// nodes; it is handed to the dial scheduler.
	discmix *enode.FairMix
	// dialsched is created by setupDialScheduler.
	dialsched *dialScheduler
	// quit is created by Start; it is passed to nat.Map and makes checkpoint
	// return ErrServerStopped once it fires.
	quit chan struct{}
	//addtrusted chan *enode.Node
	//removetrusted chan *enode.Node
	//peerOp chan peerOpFunc
	//peerOpDone chan struct{}
	//delpeer chan peerDrop
	// checkpointPostHandshake and checkpointAddPeer are the two stages that
	// setupConn submits a connection to (via checkpoint) and then waits on.
	checkpointPostHandshake chan *conn
	checkpointAddPeer chan *conn
	//inboundHistory expHeap
}
  51. func (server *Server) Start() (err error) {
  52. server.quit = make(chan struct{})
  53. server.checkpointPostHandshake = make(chan *conn)
  54. server.checkpointAddPeer = make(chan *conn)
  55. if err := server.setupBootstrapNodes(); err != nil {
  56. return err
  57. }
  58. if err := server.setupLocalNode(); err != nil {
  59. return err
  60. }
  61. if err := server.setupListening(); err != nil {
  62. return err
  63. }
  64. if err := server.setupDiscovery(); err != nil {
  65. return err
  66. }
  67. server.setupDialScheduler()
  68. return nil
  69. }
  70. // 配置节点发现逻辑
  71. func (server *Server) setupDiscovery() (err error) {
  72. server.discmix = enode.NewFairMix(discmixTimeout)
  73. // 添加特定协议的发现源。
  74. added := make(map[string]bool)
  75. for _, proto := range server.Protocols {
  76. if proto.DialCandidates != nil && !added[proto.Name] {
  77. server.discmix.AddSource(proto.DialCandidates)
  78. added[proto.Name] = true
  79. }
  80. }
  81. addr, err := net.ResolveUDPAddr("udp", server.ListenAddr)
  82. if err != nil {
  83. return err
  84. }
  85. conn, err := net.ListenUDP("udp", addr)
  86. if err != nil {
  87. return err
  88. }
  89. realAddr := conn.LocalAddr().(*net.UDPAddr)
  90. fmt.Printf("UDP listener up, addr: %v.", realAddr)
  91. if server.NAT != nil {
  92. if !realAddr.IP.IsLoopback() {
  93. server.loopWG.Add(1)
  94. gopool.Submit(func() {
  95. nat.Map(server.NAT, server.quit, "udp", realAddr.Port, realAddr.Port, "ethereum discovery")
  96. server.loopWG.Done()
  97. })
  98. }
  99. }
  100. server.localnode.SetFallbackUDP(realAddr.Port)
  101. // 设置V4的发现协议
  102. var unhandled chan discover.ReadPacket
  103. //var sconn *sharedUDPConn
  104. //if server.DiscoveryV5 {
  105. // unhandled = make(chan discover.ReadPacket, 100)
  106. // sconn = &sharedUDPConn{conn, unhandled}
  107. //}
  108. cfg := discover.Config{
  109. PrivateKey: server.PrivateKey,
  110. NetRestrict: server.NetRestrict,
  111. Bootnodes: server.BootstrapNodes,
  112. Unhandled: unhandled,
  113. //Log: server.log,
  114. }
  115. ntab, err := discover.ListenV4(conn, server.localnode, cfg)
  116. if err != nil {
  117. return err
  118. }
  119. server.ntab = ntab
  120. server.discmix.AddSource(ntab.RandomNodes())
  121. return nil
  122. }
  123. func (server *Server) setupBootstrapNodes() (err error) {
  124. urls := params.MainnetBootNodes
  125. server.BootstrapNodes = make([]*enode.Node, 0, len(urls))
  126. for _, url := range urls {
  127. if url != "" {
  128. node, err := enode.Parse(enode.ValidSchemes, url)
  129. if err != nil {
  130. return err
  131. }
  132. server.BootstrapNodes = append(server.BootstrapNodes, node)
  133. }
  134. }
  135. return nil
  136. }
  137. // 设置拨号调度器
  138. func (server *Server) setupDialScheduler() {
  139. config := dialConfig{
  140. self: server.localnode.ID(),
  141. maxDialPeers: server.maxDialedConns(),
  142. maxActiveDials: server.MaxPendingPeers,
  143. log: server.Logger,
  144. netRestrict: server.NetRestrict,
  145. dialer: server.Dialer,
  146. clock: server.clock,
  147. }
  148. if server.ntab != nil {
  149. config.resolver = server.ntab
  150. }
  151. if config.dialer == nil {
  152. config.dialer = tcpDialer{&net.Dialer{Timeout: defaultDialTimeout}}
  153. }
  154. server.dialsched = newDialScheduler(config, server.discmix, server.SetupConn)
  155. //for _, n := range srv.StaticNodes {
  156. // srv.dialsched.addStatic(n)
  157. //}
  158. }
  159. func (server *Server) maxDialedConns() (limit int) {
  160. if server.NoDial || server.MaxPeers == 0 {
  161. return 0
  162. }
  163. if server.DialRatio == 0 {
  164. limit = server.MaxPeers / defaultDialRatio
  165. } else {
  166. limit = server.MaxPeers / server.DialRatio
  167. }
  168. if limit == 0 {
  169. limit = 1
  170. }
  171. return limit
  172. }
  173. // 配置本地节点
  174. func (server *Server) setupLocalNode() (err error) {
  175. // 创建握手所需对象
  176. publicKey := crypto.FromECDSAPub(&server.PrivateKey.PublicKey)
  177. server.ourHandshake = &protoHandshake{
  178. Version: baseProtocolVersion,
  179. Name: server.Name,
  180. ID: publicKey[1:],
  181. }
  182. // 创建本地节点
  183. server.localnode = enode.NewLocalNode(server.PrivateKey)
  184. server.localnode.SetFallbackIP(net.IP{127, 0, 0, 1})
  185. // 配置本地静态IP
  186. ip, _ := server.NAT.ExternalIP()
  187. server.localnode.SetStaticIP(ip)
  188. return nil
  189. }
  190. // 监听器
  191. func (server *Server) setupListening() (err error) {
  192. listener, err := net.Listen("tcp", server.ListenAddr)
  193. if err != nil {
  194. return err
  195. }
  196. server.listener = listener
  197. server.ListenAddr = listener.Addr().String()
  198. if tcp, ok := listener.Addr().(*net.TCPAddr); ok {
  199. server.localnode.Set(enr.TCP(tcp.Port))
  200. if !tcp.IP.IsLoopback() && server.NAT != nil {
  201. server.loopWG.Add(1)
  202. gopool.Submit(func() {
  203. nat.Map(server.NAT, server.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p")
  204. server.loopWG.Done()
  205. })
  206. }
  207. }
  208. server.loopWG.Add(1)
  209. go server.listenLoop()
  210. return nil
  211. }
  212. func (server *Server) listenLoop() {
  213. fmt.Printf("TCP Listener up, addr: %v.", server.listener.Addr())
  214. tokens := defaultMaxPendingPeers
  215. slots := make(chan struct{}, tokens)
  216. for i := 0; i < tokens; i++ {
  217. slots <- struct{}{}
  218. }
  219. defer server.loopWG.Done()
  220. defer func() {
  221. for i := 0; i < cap(slots); i++ {
  222. <-slots
  223. }
  224. }()
  225. for {
  226. <-slots
  227. var (
  228. fd net.Conn
  229. err error
  230. lastLogTime time.Time
  231. )
  232. // accept处理
  233. for {
  234. fd, err = server.listener.Accept()
  235. if netutil.IsTemporaryError(err) {
  236. if time.Since(lastLogTime) > 1*time.Second {
  237. fmt.Errorf("temporary read error, err: %v", err)
  238. lastLogTime = time.Now()
  239. }
  240. time.Sleep(time.Millisecond * 200)
  241. continue
  242. } else if err != nil {
  243. fmt.Errorf("read error, err: %v", err)
  244. slots <- struct{}{}
  245. return
  246. }
  247. break
  248. }
  249. // accept成功的处理
  250. remoteIP := netutil.AddrIP(fd.RemoteAddr())
  251. // TODO 检查此IP是是否能加入本地节点的链接
  252. //if err := server.checkInboundConn(remoteIP); err != nil {
  253. // srv.log.Debug("Rejected inbound connection", "addr", fd.RemoteAddr(), "err", err)
  254. // fd.Close()
  255. // slots <- struct{}{}
  256. // continue
  257. //}
  258. if remoteIP != nil {
  259. var addr *net.TCPAddr
  260. if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok {
  261. addr = tcp
  262. }
  263. fd = newMeteredConn(fd, true, addr)
  264. fmt.Printf("Accepted connection, addr: %v.", fd.RemoteAddr())
  265. }
  266. gopool.Submit(func() {
  267. server.SetupConn(fd, inboundConn, nil)
  268. slots <- struct{}{}
  269. })
  270. }
  271. }
  272. func (server *Server) newRLPX(conn net.Conn, dialDest *ecdsa.PublicKey) transport {
  273. return &rlpxTransport{conn: rlpx.NewConn(conn, dialDest)}
  274. }
  275. func (server *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error {
  276. c := &conn{fd: fd, flags: flags, cont: make(chan error)}
  277. if dialDest == nil {
  278. c.transport = server.newRLPX(fd, nil)
  279. } else {
  280. c.transport = server.newRLPX(fd, dialDest.Pubkey())
  281. }
  282. err := server.setupConn(c, dialDest)
  283. if err != nil {
  284. c.close(err)
  285. }
  286. return err
  287. }
  288. func (server *Server) setupConn(c *conn, dialDest *enode.Node) error {
  289. remotePubkey, err := c.doEncHandshake(server.PrivateKey)
  290. if err != nil {
  291. return err
  292. }
  293. // 将connection转换成node
  294. c.node = enode.NodeFromConn(remotePubkey, c.fd)
  295. fmt.Printf("id: %v, addr: %v, conn: %v", c.node.ID(), c.fd.RemoteAddr(), c.flags)
  296. // 检查是否需要握手
  297. err = server.checkpoint(c, server.checkpointPostHandshake)
  298. if err != nil {
  299. return err
  300. }
  301. // 进行握手
  302. phs, err := c.doProtoHandshake(server.ourHandshake)
  303. if err != nil {
  304. return err
  305. }
  306. c.caps, c.name = phs.Caps, phs.Name
  307. // 将此链接放入addPeer的检查点
  308. err = server.checkpoint(c, server.checkpointAddPeer)
  309. if err != nil {
  310. return err
  311. }
  312. return nil
  313. }
  314. func (server *Server) checkpoint(c *conn, stage chan<- *conn) error {
  315. select {
  316. case stage <- c:
  317. case <-server.quit:
  318. return ErrServerStopped
  319. }
  320. return <-c.cont
  321. }