server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. // Copyright 2014 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package p2p
  17. import (
  18. "crypto/ecdsa"
  19. "errors"
  20. "math/rand"
  21. "net"
  22. "reflect"
  23. "testing"
  24. "time"
  25. "github.com/ethereum/go-ethereum/crypto"
  26. "github.com/ethereum/go-ethereum/crypto/sha3"
  27. "github.com/ethereum/go-ethereum/p2p/discover"
  28. )
  // init deliberately does nothing. The commented-out statement below can be
  // re-enabled while debugging to route error-level log records to stderr.
  func init() {
  	// log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
  }
  32. type testTransport struct {
  33. id discover.NodeID
  34. *rlpx
  35. closeErr error
  36. }
  37. func newTestTransport(id discover.NodeID, fd net.Conn) transport {
  38. wrapped := newRLPX(fd).(*rlpx)
  39. wrapped.rw = newRLPXFrameRW(fd, secrets{
  40. MAC: zero16,
  41. AES: zero16,
  42. IngressMAC: sha3.NewKeccak256(),
  43. EgressMAC: sha3.NewKeccak256(),
  44. })
  45. return &testTransport{id: id, rlpx: wrapped}
  46. }
  47. func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  48. return c.id, nil
  49. }
  50. func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  51. return &protoHandshake{ID: c.id, Name: "test"}, nil
  52. }
  53. func (c *testTransport) close(err error) {
  54. c.rlpx.fd.Close()
  55. c.closeErr = err
  56. }
  57. func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
  58. config := Config{
  59. Name: "test",
  60. MaxPeers: 10,
  61. ListenAddr: "127.0.0.1:0",
  62. PrivateKey: newkey(),
  63. }
  64. server := &Server{
  65. Config: config,
  66. newPeerHook: pf,
  67. newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
  68. }
  69. if err := server.Start(); err != nil {
  70. t.Fatalf("Could not start server: %v", err)
  71. }
  72. return server
  73. }
  74. func TestServerListen(t *testing.T) {
  75. // start the test server
  76. connected := make(chan *Peer)
  77. remid := randomID()
  78. srv := startTestServer(t, remid, func(p *Peer) {
  79. if p.ID() != remid {
  80. t.Error("peer func called with wrong node id")
  81. }
  82. if p == nil {
  83. t.Error("peer func called with nil conn")
  84. }
  85. connected <- p
  86. })
  87. defer close(connected)
  88. defer srv.Stop()
  89. // dial the test server
  90. conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
  91. if err != nil {
  92. t.Fatalf("could not dial: %v", err)
  93. }
  94. defer conn.Close()
  95. select {
  96. case peer := <-connected:
  97. if peer.LocalAddr().String() != conn.RemoteAddr().String() {
  98. t.Errorf("peer started with wrong conn: got %v, want %v",
  99. peer.LocalAddr(), conn.RemoteAddr())
  100. }
  101. peers := srv.Peers()
  102. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  103. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  104. }
  105. case <-time.After(1 * time.Second):
  106. t.Error("server did not accept within one second")
  107. }
  108. }
  109. func TestServerDial(t *testing.T) {
  110. // run a one-shot TCP server to handle the connection.
  111. listener, err := net.Listen("tcp", "127.0.0.1:0")
  112. if err != nil {
  113. t.Fatalf("could not setup listener: %v", err)
  114. }
  115. defer listener.Close()
  116. accepted := make(chan net.Conn)
  117. go func() {
  118. conn, err := listener.Accept()
  119. if err != nil {
  120. t.Error("accept error:", err)
  121. return
  122. }
  123. accepted <- conn
  124. }()
  125. // start the server
  126. connected := make(chan *Peer)
  127. remid := randomID()
  128. srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
  129. defer close(connected)
  130. defer srv.Stop()
  131. // tell the server to connect
  132. tcpAddr := listener.Addr().(*net.TCPAddr)
  133. srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
  134. select {
  135. case conn := <-accepted:
  136. defer conn.Close()
  137. select {
  138. case peer := <-connected:
  139. if peer.ID() != remid {
  140. t.Errorf("peer has wrong id")
  141. }
  142. if peer.Name() != "test" {
  143. t.Errorf("peer has wrong name")
  144. }
  145. if peer.RemoteAddr().String() != conn.LocalAddr().String() {
  146. t.Errorf("peer started with wrong conn: got %v, want %v",
  147. peer.RemoteAddr(), conn.LocalAddr())
  148. }
  149. peers := srv.Peers()
  150. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  151. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  152. }
  153. case <-time.After(1 * time.Second):
  154. t.Error("server did not launch peer within one second")
  155. }
  156. case <-time.After(1 * time.Second):
  157. t.Error("server did not connect within one second")
  158. }
  159. }
  160. // This test checks that tasks generated by dialstate are
  161. // actually executed and taskdone is called for them.
  162. func TestServerTaskScheduling(t *testing.T) {
  163. var (
  164. done = make(chan *testTask)
  165. quit, returned = make(chan struct{}), make(chan struct{})
  166. tc = 0
  167. tg = taskgen{
  168. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  169. tc++
  170. return []task{&testTask{index: tc - 1}}
  171. },
  172. doneFunc: func(t task) {
  173. select {
  174. case done <- t.(*testTask):
  175. case <-quit:
  176. }
  177. },
  178. }
  179. )
  180. // The Server in this test isn't actually running
  181. // because we're only interested in what run does.
  182. srv := &Server{
  183. Config: Config{MaxPeers: 10},
  184. quit: make(chan struct{}),
  185. ntab: fakeTable{},
  186. running: true,
  187. }
  188. srv.loopWG.Add(1)
  189. go func() {
  190. srv.run(tg)
  191. close(returned)
  192. }()
  193. var gotdone []*testTask
  194. for i := 0; i < 100; i++ {
  195. gotdone = append(gotdone, <-done)
  196. }
  197. for i, task := range gotdone {
  198. if task.index != i {
  199. t.Errorf("task %d has wrong index, got %d", i, task.index)
  200. break
  201. }
  202. if !task.called {
  203. t.Errorf("task %d was not called", i)
  204. break
  205. }
  206. }
  207. close(quit)
  208. srv.Stop()
  209. select {
  210. case <-returned:
  211. case <-time.After(500 * time.Millisecond):
  212. t.Error("Server.run did not return within 500ms")
  213. }
  214. }
  215. // This test checks that Server doesn't drop tasks,
  216. // even if newTasks returns more than the maximum number of tasks.
  217. func TestServerManyTasks(t *testing.T) {
  218. alltasks := make([]task, 300)
  219. for i := range alltasks {
  220. alltasks[i] = &testTask{index: i}
  221. }
  222. var (
  223. srv = &Server{quit: make(chan struct{}), ntab: fakeTable{}, running: true}
  224. done = make(chan *testTask)
  225. start, end = 0, 0
  226. )
  227. defer srv.Stop()
  228. srv.loopWG.Add(1)
  229. go srv.run(taskgen{
  230. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  231. start, end = end, end+maxActiveDialTasks+10
  232. if end > len(alltasks) {
  233. end = len(alltasks)
  234. }
  235. return alltasks[start:end]
  236. },
  237. doneFunc: func(tt task) {
  238. done <- tt.(*testTask)
  239. },
  240. })
  241. doneset := make(map[int]bool)
  242. timeout := time.After(2 * time.Second)
  243. for len(doneset) < len(alltasks) {
  244. select {
  245. case tt := <-done:
  246. if doneset[tt.index] {
  247. t.Errorf("task %d got done more than once", tt.index)
  248. } else {
  249. doneset[tt.index] = true
  250. }
  251. case <-timeout:
  252. t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
  253. for i := 0; i < len(alltasks); i++ {
  254. if !doneset[i] {
  255. t.Logf("task %d not done", i)
  256. }
  257. }
  258. return
  259. }
  260. }
  261. }
  262. type taskgen struct {
  263. newFunc func(running int, peers map[discover.NodeID]*Peer) []task
  264. doneFunc func(task)
  265. }
  266. func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
  267. return tg.newFunc(running, peers)
  268. }
  269. func (tg taskgen) taskDone(t task, now time.Time) {
  270. tg.doneFunc(t)
  271. }
  272. func (tg taskgen) addStatic(*discover.Node) {
  273. }
  274. func (tg taskgen) removeStatic(*discover.Node) {
  275. }
  276. type testTask struct {
  277. index int
  278. called bool
  279. }
  280. func (t *testTask) Do(srv *Server) {
  281. t.called = true
  282. }
  283. // This test checks that connections are disconnected
  284. // just after the encryption handshake when the server is
  285. // at capacity. Trusted connections should still be accepted.
  286. func TestServerAtCap(t *testing.T) {
  287. trustedID := randomID()
  288. srv := &Server{
  289. Config: Config{
  290. PrivateKey: newkey(),
  291. MaxPeers: 10,
  292. NoDial: true,
  293. TrustedNodes: []*discover.Node{{ID: trustedID}},
  294. },
  295. }
  296. if err := srv.Start(); err != nil {
  297. t.Fatalf("could not start: %v", err)
  298. }
  299. defer srv.Stop()
  300. newconn := func(id discover.NodeID) *conn {
  301. fd, _ := net.Pipe()
  302. tx := newTestTransport(id, fd)
  303. return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
  304. }
  305. // Inject a few connections to fill up the peer set.
  306. for i := 0; i < 10; i++ {
  307. c := newconn(randomID())
  308. if err := srv.checkpoint(c, srv.addpeer); err != nil {
  309. t.Fatalf("could not add conn %d: %v", i, err)
  310. }
  311. }
  312. // Try inserting a non-trusted connection.
  313. c := newconn(randomID())
  314. if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
  315. t.Error("wrong error for insert:", err)
  316. }
  317. // Try inserting a trusted connection.
  318. c = newconn(trustedID)
  319. if err := srv.checkpoint(c, srv.posthandshake); err != nil {
  320. t.Error("unexpected error for trusted conn @posthandshake:", err)
  321. }
  322. if !c.is(trustedConn) {
  323. t.Error("Server did not set trusted flag")
  324. }
  325. }
  326. func TestServerSetupConn(t *testing.T) {
  327. id := randomID()
  328. srvkey := newkey()
  329. srvid := discover.PubkeyID(&srvkey.PublicKey)
  330. tests := []struct {
  331. dontstart bool
  332. tt *setupTransport
  333. flags connFlag
  334. dialDest *discover.Node
  335. wantCloseErr error
  336. wantCalls string
  337. }{
  338. {
  339. dontstart: true,
  340. tt: &setupTransport{id: id},
  341. wantCalls: "close,",
  342. wantCloseErr: errServerStopped,
  343. },
  344. {
  345. tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
  346. flags: inboundConn,
  347. wantCalls: "doEncHandshake,close,",
  348. wantCloseErr: errors.New("read error"),
  349. },
  350. {
  351. tt: &setupTransport{id: id},
  352. dialDest: &discover.Node{ID: randomID()},
  353. flags: dynDialedConn,
  354. wantCalls: "doEncHandshake,close,",
  355. wantCloseErr: DiscUnexpectedIdentity,
  356. },
  357. {
  358. tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
  359. dialDest: &discover.Node{ID: id},
  360. flags: dynDialedConn,
  361. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  362. wantCloseErr: DiscUnexpectedIdentity,
  363. },
  364. {
  365. tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
  366. dialDest: &discover.Node{ID: id},
  367. flags: dynDialedConn,
  368. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  369. wantCloseErr: errors.New("foo"),
  370. },
  371. {
  372. tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
  373. flags: inboundConn,
  374. wantCalls: "doEncHandshake,close,",
  375. wantCloseErr: DiscSelf,
  376. },
  377. {
  378. tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
  379. flags: inboundConn,
  380. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  381. wantCloseErr: DiscUselessPeer,
  382. },
  383. }
  384. for i, test := range tests {
  385. srv := &Server{
  386. Config: Config{
  387. PrivateKey: srvkey,
  388. MaxPeers: 10,
  389. NoDial: true,
  390. Protocols: []Protocol{discard},
  391. },
  392. newTransport: func(fd net.Conn) transport { return test.tt },
  393. }
  394. if !test.dontstart {
  395. if err := srv.Start(); err != nil {
  396. t.Fatalf("couldn't start server: %v", err)
  397. }
  398. }
  399. p1, _ := net.Pipe()
  400. srv.setupConn(p1, test.flags, test.dialDest)
  401. if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
  402. t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
  403. }
  404. if test.tt.calls != test.wantCalls {
  405. t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
  406. }
  407. }
  408. }
  409. type setupTransport struct {
  410. id discover.NodeID
  411. encHandshakeErr error
  412. phs *protoHandshake
  413. protoHandshakeErr error
  414. calls string
  415. closeErr error
  416. }
  417. func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  418. c.calls += "doEncHandshake,"
  419. return c.id, c.encHandshakeErr
  420. }
  421. func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  422. c.calls += "doProtoHandshake,"
  423. if c.protoHandshakeErr != nil {
  424. return nil, c.protoHandshakeErr
  425. }
  426. return c.phs, nil
  427. }
  428. func (c *setupTransport) close(err error) {
  429. c.calls += "close,"
  430. c.closeErr = err
  431. }
  432. // setupConn shouldn't write to/read from the connection.
  433. func (c *setupTransport) WriteMsg(Msg) error {
  434. panic("WriteMsg called on setupTransport")
  435. }
  436. func (c *setupTransport) ReadMsg() (Msg, error) {
  437. panic("ReadMsg called on setupTransport")
  438. }
  439. func newkey() *ecdsa.PrivateKey {
  440. key, err := crypto.GenerateKey()
  441. if err != nil {
  442. panic("couldn't generate key: " + err.Error())
  443. }
  444. return key
  445. }
  446. func randomID() (id discover.NodeID) {
  447. for i := range id {
  448. id[i] = byte(rand.Intn(255))
  449. }
  450. return id
  451. }