server_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. // Copyright 2014 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package p2p
  17. import (
  18. "crypto/ecdsa"
  19. "errors"
  20. "math/rand"
  21. "net"
  22. "reflect"
  23. "testing"
  24. "time"
  25. "github.com/ethereum/go-ethereum/crypto"
  26. "github.com/ethereum/go-ethereum/crypto/sha3"
  27. "github.com/ethereum/go-ethereum/p2p/discover"
  28. )
  29. func init() {
  30. // glog.SetV(6)
  31. // glog.SetToStderr(true)
  32. }
  33. type testTransport struct {
  34. id discover.NodeID
  35. *rlpx
  36. closeErr error
  37. }
  38. func newTestTransport(id discover.NodeID, fd net.Conn) transport {
  39. wrapped := newRLPX(fd).(*rlpx)
  40. wrapped.rw = newRLPXFrameRW(fd, secrets{
  41. MAC: zero16,
  42. AES: zero16,
  43. IngressMAC: sha3.NewKeccak256(),
  44. EgressMAC: sha3.NewKeccak256(),
  45. })
  46. return &testTransport{id: id, rlpx: wrapped}
  47. }
  48. func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  49. return c.id, nil
  50. }
  51. func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  52. return &protoHandshake{ID: c.id, Name: "test"}, nil
  53. }
  54. func (c *testTransport) close(err error) {
  55. c.rlpx.fd.Close()
  56. c.closeErr = err
  57. }
  58. func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
  59. config := Config{
  60. Name: "test",
  61. MaxPeers: 10,
  62. ListenAddr: "127.0.0.1:0",
  63. PrivateKey: newkey(),
  64. }
  65. server := &Server{
  66. Config: config,
  67. newPeerHook: pf,
  68. newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
  69. }
  70. if err := server.Start(); err != nil {
  71. t.Fatalf("Could not start server: %v", err)
  72. }
  73. return server
  74. }
  75. func TestServerListen(t *testing.T) {
  76. // start the test server
  77. connected := make(chan *Peer)
  78. remid := randomID()
  79. srv := startTestServer(t, remid, func(p *Peer) {
  80. if p.ID() != remid {
  81. t.Error("peer func called with wrong node id")
  82. }
  83. if p == nil {
  84. t.Error("peer func called with nil conn")
  85. }
  86. connected <- p
  87. })
  88. defer close(connected)
  89. defer srv.Stop()
  90. // dial the test server
  91. conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
  92. if err != nil {
  93. t.Fatalf("could not dial: %v", err)
  94. }
  95. defer conn.Close()
  96. select {
  97. case peer := <-connected:
  98. if peer.LocalAddr().String() != conn.RemoteAddr().String() {
  99. t.Errorf("peer started with wrong conn: got %v, want %v",
  100. peer.LocalAddr(), conn.RemoteAddr())
  101. }
  102. peers := srv.Peers()
  103. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  104. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  105. }
  106. case <-time.After(1 * time.Second):
  107. t.Error("server did not accept within one second")
  108. }
  109. }
  110. func TestServerDial(t *testing.T) {
  111. // run a one-shot TCP server to handle the connection.
  112. listener, err := net.Listen("tcp", "127.0.0.1:0")
  113. if err != nil {
  114. t.Fatalf("could not setup listener: %v", err)
  115. }
  116. defer listener.Close()
  117. accepted := make(chan net.Conn)
  118. go func() {
  119. conn, err := listener.Accept()
  120. if err != nil {
  121. t.Error("accept error:", err)
  122. return
  123. }
  124. accepted <- conn
  125. }()
  126. // start the server
  127. connected := make(chan *Peer)
  128. remid := randomID()
  129. srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
  130. defer close(connected)
  131. defer srv.Stop()
  132. // tell the server to connect
  133. tcpAddr := listener.Addr().(*net.TCPAddr)
  134. srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
  135. select {
  136. case conn := <-accepted:
  137. defer conn.Close()
  138. select {
  139. case peer := <-connected:
  140. if peer.ID() != remid {
  141. t.Errorf("peer has wrong id")
  142. }
  143. if peer.Name() != "test" {
  144. t.Errorf("peer has wrong name")
  145. }
  146. if peer.RemoteAddr().String() != conn.LocalAddr().String() {
  147. t.Errorf("peer started with wrong conn: got %v, want %v",
  148. peer.RemoteAddr(), conn.LocalAddr())
  149. }
  150. peers := srv.Peers()
  151. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  152. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  153. }
  154. case <-time.After(1 * time.Second):
  155. t.Error("server did not launch peer within one second")
  156. }
  157. case <-time.After(1 * time.Second):
  158. t.Error("server did not connect within one second")
  159. }
  160. }
  161. // This test checks that tasks generated by dialstate are
  162. // actually executed and taskdone is called for them.
  163. func TestServerTaskScheduling(t *testing.T) {
  164. var (
  165. done = make(chan *testTask)
  166. quit, returned = make(chan struct{}), make(chan struct{})
  167. tc = 0
  168. tg = taskgen{
  169. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  170. tc++
  171. return []task{&testTask{index: tc - 1}}
  172. },
  173. doneFunc: func(t task) {
  174. select {
  175. case done <- t.(*testTask):
  176. case <-quit:
  177. }
  178. },
  179. }
  180. )
  181. // The Server in this test isn't actually running
  182. // because we're only interested in what run does.
  183. srv := &Server{
  184. Config: Config{MaxPeers: 10},
  185. quit: make(chan struct{}),
  186. ntab: fakeTable{},
  187. running: true,
  188. }
  189. srv.loopWG.Add(1)
  190. go func() {
  191. srv.run(tg)
  192. close(returned)
  193. }()
  194. var gotdone []*testTask
  195. for i := 0; i < 100; i++ {
  196. gotdone = append(gotdone, <-done)
  197. }
  198. for i, task := range gotdone {
  199. if task.index != i {
  200. t.Errorf("task %d has wrong index, got %d", i, task.index)
  201. break
  202. }
  203. if !task.called {
  204. t.Errorf("task %d was not called", i)
  205. break
  206. }
  207. }
  208. close(quit)
  209. srv.Stop()
  210. select {
  211. case <-returned:
  212. case <-time.After(500 * time.Millisecond):
  213. t.Error("Server.run did not return within 500ms")
  214. }
  215. }
  216. // This test checks that Server doesn't drop tasks,
  217. // even if newTasks returns more than the maximum number of tasks.
  218. func TestServerManyTasks(t *testing.T) {
  219. alltasks := make([]task, 300)
  220. for i := range alltasks {
  221. alltasks[i] = &testTask{index: i}
  222. }
  223. var (
  224. srv = &Server{quit: make(chan struct{}), ntab: fakeTable{}, running: true}
  225. done = make(chan *testTask)
  226. start, end = 0, 0
  227. )
  228. defer srv.Stop()
  229. srv.loopWG.Add(1)
  230. go srv.run(taskgen{
  231. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  232. start, end = end, end+maxActiveDialTasks+10
  233. if end > len(alltasks) {
  234. end = len(alltasks)
  235. }
  236. return alltasks[start:end]
  237. },
  238. doneFunc: func(tt task) {
  239. done <- tt.(*testTask)
  240. },
  241. })
  242. doneset := make(map[int]bool)
  243. timeout := time.After(2 * time.Second)
  244. for len(doneset) < len(alltasks) {
  245. select {
  246. case tt := <-done:
  247. if doneset[tt.index] {
  248. t.Errorf("task %d got done more than once", tt.index)
  249. } else {
  250. doneset[tt.index] = true
  251. }
  252. case <-timeout:
  253. t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
  254. for i := 0; i < len(alltasks); i++ {
  255. if !doneset[i] {
  256. t.Logf("task %d not done", i)
  257. }
  258. }
  259. return
  260. }
  261. }
  262. }
  263. type taskgen struct {
  264. newFunc func(running int, peers map[discover.NodeID]*Peer) []task
  265. doneFunc func(task)
  266. }
  267. func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
  268. return tg.newFunc(running, peers)
  269. }
  270. func (tg taskgen) taskDone(t task, now time.Time) {
  271. tg.doneFunc(t)
  272. }
  273. func (tg taskgen) addStatic(*discover.Node) {
  274. }
  275. type testTask struct {
  276. index int
  277. called bool
  278. }
  279. func (t *testTask) Do(srv *Server) {
  280. t.called = true
  281. }
  282. // This test checks that connections are disconnected
  283. // just after the encryption handshake when the server is
  284. // at capacity. Trusted connections should still be accepted.
  285. func TestServerAtCap(t *testing.T) {
  286. trustedID := randomID()
  287. srv := &Server{
  288. Config: Config{
  289. PrivateKey: newkey(),
  290. MaxPeers: 10,
  291. NoDial: true,
  292. TrustedNodes: []*discover.Node{{ID: trustedID}},
  293. },
  294. }
  295. if err := srv.Start(); err != nil {
  296. t.Fatalf("could not start: %v", err)
  297. }
  298. defer srv.Stop()
  299. newconn := func(id discover.NodeID) *conn {
  300. fd, _ := net.Pipe()
  301. tx := newTestTransport(id, fd)
  302. return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
  303. }
  304. // Inject a few connections to fill up the peer set.
  305. for i := 0; i < 10; i++ {
  306. c := newconn(randomID())
  307. if err := srv.checkpoint(c, srv.addpeer); err != nil {
  308. t.Fatalf("could not add conn %d: %v", i, err)
  309. }
  310. }
  311. // Try inserting a non-trusted connection.
  312. c := newconn(randomID())
  313. if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
  314. t.Error("wrong error for insert:", err)
  315. }
  316. // Try inserting a trusted connection.
  317. c = newconn(trustedID)
  318. if err := srv.checkpoint(c, srv.posthandshake); err != nil {
  319. t.Error("unexpected error for trusted conn @posthandshake:", err)
  320. }
  321. if !c.is(trustedConn) {
  322. t.Error("Server did not set trusted flag")
  323. }
  324. }
  325. func TestServerSetupConn(t *testing.T) {
  326. id := randomID()
  327. srvkey := newkey()
  328. srvid := discover.PubkeyID(&srvkey.PublicKey)
  329. tests := []struct {
  330. dontstart bool
  331. tt *setupTransport
  332. flags connFlag
  333. dialDest *discover.Node
  334. wantCloseErr error
  335. wantCalls string
  336. }{
  337. {
  338. dontstart: true,
  339. tt: &setupTransport{id: id},
  340. wantCalls: "close,",
  341. wantCloseErr: errServerStopped,
  342. },
  343. {
  344. tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
  345. flags: inboundConn,
  346. wantCalls: "doEncHandshake,close,",
  347. wantCloseErr: errors.New("read error"),
  348. },
  349. {
  350. tt: &setupTransport{id: id},
  351. dialDest: &discover.Node{ID: randomID()},
  352. flags: dynDialedConn,
  353. wantCalls: "doEncHandshake,close,",
  354. wantCloseErr: DiscUnexpectedIdentity,
  355. },
  356. {
  357. tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
  358. dialDest: &discover.Node{ID: id},
  359. flags: dynDialedConn,
  360. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  361. wantCloseErr: DiscUnexpectedIdentity,
  362. },
  363. {
  364. tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
  365. dialDest: &discover.Node{ID: id},
  366. flags: dynDialedConn,
  367. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  368. wantCloseErr: errors.New("foo"),
  369. },
  370. {
  371. tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
  372. flags: inboundConn,
  373. wantCalls: "doEncHandshake,close,",
  374. wantCloseErr: DiscSelf,
  375. },
  376. {
  377. tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
  378. flags: inboundConn,
  379. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  380. wantCloseErr: DiscUselessPeer,
  381. },
  382. }
  383. for i, test := range tests {
  384. srv := &Server{
  385. Config: Config{
  386. PrivateKey: srvkey,
  387. MaxPeers: 10,
  388. NoDial: true,
  389. Protocols: []Protocol{discard},
  390. },
  391. newTransport: func(fd net.Conn) transport { return test.tt },
  392. }
  393. if !test.dontstart {
  394. if err := srv.Start(); err != nil {
  395. t.Fatalf("couldn't start server: %v", err)
  396. }
  397. }
  398. p1, _ := net.Pipe()
  399. srv.setupConn(p1, test.flags, test.dialDest)
  400. if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
  401. t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
  402. }
  403. if test.tt.calls != test.wantCalls {
  404. t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
  405. }
  406. }
  407. }
  408. type setupTransport struct {
  409. id discover.NodeID
  410. encHandshakeErr error
  411. phs *protoHandshake
  412. protoHandshakeErr error
  413. calls string
  414. closeErr error
  415. }
  416. func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  417. c.calls += "doEncHandshake,"
  418. return c.id, c.encHandshakeErr
  419. }
  420. func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  421. c.calls += "doProtoHandshake,"
  422. if c.protoHandshakeErr != nil {
  423. return nil, c.protoHandshakeErr
  424. }
  425. return c.phs, nil
  426. }
  427. func (c *setupTransport) close(err error) {
  428. c.calls += "close,"
  429. c.closeErr = err
  430. }
  431. // setupConn shouldn't write to/read from the connection.
  432. func (c *setupTransport) WriteMsg(Msg) error {
  433. panic("WriteMsg called on setupTransport")
  434. }
  435. func (c *setupTransport) ReadMsg() (Msg, error) {
  436. panic("ReadMsg called on setupTransport")
  437. }
  438. func newkey() *ecdsa.PrivateKey {
  439. key, err := crypto.GenerateKey()
  440. if err != nil {
  441. panic("couldn't generate key: " + err.Error())
  442. }
  443. return key
  444. }
  445. func randomID() (id discover.NodeID) {
  446. for i := range id {
  447. id[i] = byte(rand.Intn(255))
  448. }
  449. return id
  450. }