server_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. package p2p
  2. import (
  3. "crypto/ecdsa"
  4. "errors"
  5. "math/rand"
  6. "net"
  7. "reflect"
  8. "testing"
  9. "time"
  10. "github.com/ethereum/go-ethereum/crypto"
  11. "github.com/ethereum/go-ethereum/crypto/sha3"
  12. "github.com/ethereum/go-ethereum/p2p/discover"
  13. )
  14. func init() {
  15. // glog.SetV(6)
  16. // glog.SetToStderr(true)
  17. }
  18. type testTransport struct {
  19. id discover.NodeID
  20. *rlpx
  21. closeErr error
  22. }
  23. func newTestTransport(id discover.NodeID, fd net.Conn) transport {
  24. wrapped := newRLPX(fd).(*rlpx)
  25. wrapped.rw = newRLPXFrameRW(fd, secrets{
  26. MAC: zero16,
  27. AES: zero16,
  28. IngressMAC: sha3.NewKeccak256(),
  29. EgressMAC: sha3.NewKeccak256(),
  30. })
  31. return &testTransport{id: id, rlpx: wrapped}
  32. }
  33. func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  34. return c.id, nil
  35. }
  36. func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  37. return &protoHandshake{ID: c.id, Name: "test"}, nil
  38. }
  39. func (c *testTransport) close(err error) {
  40. c.rlpx.fd.Close()
  41. c.closeErr = err
  42. }
  43. func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
  44. server := &Server{
  45. Name: "test",
  46. MaxPeers: 10,
  47. ListenAddr: "127.0.0.1:0",
  48. PrivateKey: newkey(),
  49. newPeerHook: pf,
  50. newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
  51. }
  52. if err := server.Start(); err != nil {
  53. t.Fatalf("Could not start server: %v", err)
  54. }
  55. return server
  56. }
  57. func TestServerListen(t *testing.T) {
  58. // start the test server
  59. connected := make(chan *Peer)
  60. remid := randomID()
  61. srv := startTestServer(t, remid, func(p *Peer) {
  62. if p.ID() != remid {
  63. t.Error("peer func called with wrong node id")
  64. }
  65. if p == nil {
  66. t.Error("peer func called with nil conn")
  67. }
  68. connected <- p
  69. })
  70. defer close(connected)
  71. defer srv.Stop()
  72. // dial the test server
  73. conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
  74. if err != nil {
  75. t.Fatalf("could not dial: %v", err)
  76. }
  77. defer conn.Close()
  78. select {
  79. case peer := <-connected:
  80. if peer.LocalAddr().String() != conn.RemoteAddr().String() {
  81. t.Errorf("peer started with wrong conn: got %v, want %v",
  82. peer.LocalAddr(), conn.RemoteAddr())
  83. }
  84. peers := srv.Peers()
  85. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  86. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  87. }
  88. case <-time.After(1 * time.Second):
  89. t.Error("server did not accept within one second")
  90. }
  91. }
  92. func TestServerDial(t *testing.T) {
  93. // run a one-shot TCP server to handle the connection.
  94. listener, err := net.Listen("tcp", "127.0.0.1:0")
  95. if err != nil {
  96. t.Fatalf("could not setup listener: %v")
  97. }
  98. defer listener.Close()
  99. accepted := make(chan net.Conn)
  100. go func() {
  101. conn, err := listener.Accept()
  102. if err != nil {
  103. t.Error("accept error:", err)
  104. return
  105. }
  106. conn.Close()
  107. accepted <- conn
  108. }()
  109. // start the server
  110. connected := make(chan *Peer)
  111. remid := randomID()
  112. srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
  113. defer close(connected)
  114. defer srv.Stop()
  115. // tell the server to connect
  116. tcpAddr := listener.Addr().(*net.TCPAddr)
  117. srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
  118. select {
  119. case conn := <-accepted:
  120. select {
  121. case peer := <-connected:
  122. if peer.ID() != remid {
  123. t.Errorf("peer has wrong id")
  124. }
  125. if peer.Name() != "test" {
  126. t.Errorf("peer has wrong name")
  127. }
  128. if peer.RemoteAddr().String() != conn.LocalAddr().String() {
  129. t.Errorf("peer started with wrong conn: got %v, want %v",
  130. peer.RemoteAddr(), conn.LocalAddr())
  131. }
  132. peers := srv.Peers()
  133. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  134. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  135. }
  136. case <-time.After(1 * time.Second):
  137. t.Error("server did not launch peer within one second")
  138. }
  139. case <-time.After(1 * time.Second):
  140. t.Error("server did not connect within one second")
  141. }
  142. }
  143. // This test checks that tasks generated by dialstate are
  144. // actually executed and taskdone is called for them.
  145. func TestServerTaskScheduling(t *testing.T) {
  146. var (
  147. done = make(chan *testTask)
  148. quit, returned = make(chan struct{}), make(chan struct{})
  149. tc = 0
  150. tg = taskgen{
  151. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  152. tc++
  153. return []task{&testTask{index: tc - 1}}
  154. },
  155. doneFunc: func(t task) {
  156. select {
  157. case done <- t.(*testTask):
  158. case <-quit:
  159. }
  160. },
  161. }
  162. )
  163. // The Server in this test isn't actually running
  164. // because we're only interested in what run does.
  165. srv := &Server{
  166. MaxPeers: 10,
  167. quit: make(chan struct{}),
  168. ntab: fakeTable{},
  169. running: true,
  170. }
  171. srv.loopWG.Add(1)
  172. go func() {
  173. srv.run(tg)
  174. close(returned)
  175. }()
  176. var gotdone []*testTask
  177. for i := 0; i < 100; i++ {
  178. gotdone = append(gotdone, <-done)
  179. }
  180. for i, task := range gotdone {
  181. if task.index != i {
  182. t.Errorf("task %d has wrong index, got %d", i, task.index)
  183. break
  184. }
  185. if !task.called {
  186. t.Errorf("task %d was not called", i)
  187. break
  188. }
  189. }
  190. close(quit)
  191. srv.Stop()
  192. select {
  193. case <-returned:
  194. case <-time.After(500 * time.Millisecond):
  195. t.Error("Server.run did not return within 500ms")
  196. }
  197. }
  198. type taskgen struct {
  199. newFunc func(running int, peers map[discover.NodeID]*Peer) []task
  200. doneFunc func(task)
  201. }
  202. func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
  203. return tg.newFunc(running, peers)
  204. }
  205. func (tg taskgen) taskDone(t task, now time.Time) {
  206. tg.doneFunc(t)
  207. }
  208. func (tg taskgen) addStatic(*discover.Node) {
  209. }
  210. type testTask struct {
  211. index int
  212. called bool
  213. }
  214. func (t *testTask) Do(srv *Server) {
  215. t.called = true
  216. }
  217. // This test checks that connections are disconnected
  218. // just after the encryption handshake when the server is
  219. // at capacity. Trusted connections should still be accepted.
  220. func TestServerAtCap(t *testing.T) {
  221. trustedID := randomID()
  222. srv := &Server{
  223. PrivateKey: newkey(),
  224. MaxPeers: 10,
  225. NoDial: true,
  226. TrustedNodes: []*discover.Node{{ID: trustedID}},
  227. }
  228. if err := srv.Start(); err != nil {
  229. t.Fatalf("could not start: %v", err)
  230. }
  231. defer srv.Stop()
  232. newconn := func(id discover.NodeID) *conn {
  233. fd, _ := net.Pipe()
  234. tx := newTestTransport(id, fd)
  235. return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
  236. }
  237. // Inject a few connections to fill up the peer set.
  238. for i := 0; i < 10; i++ {
  239. c := newconn(randomID())
  240. if err := srv.checkpoint(c, srv.addpeer); err != nil {
  241. t.Fatalf("could not add conn %d: %v", i, err)
  242. }
  243. }
  244. // Try inserting a non-trusted connection.
  245. c := newconn(randomID())
  246. if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
  247. t.Error("wrong error for insert:", err)
  248. }
  249. // Try inserting a trusted connection.
  250. c = newconn(trustedID)
  251. if err := srv.checkpoint(c, srv.posthandshake); err != nil {
  252. t.Error("unexpected error for trusted conn @posthandshake:", err)
  253. }
  254. if !c.is(trustedConn) {
  255. t.Error("Server did not set trusted flag")
  256. }
  257. }
  258. func TestServerSetupConn(t *testing.T) {
  259. id := randomID()
  260. srvkey := newkey()
  261. srvid := discover.PubkeyID(&srvkey.PublicKey)
  262. tests := []struct {
  263. dontstart bool
  264. tt *setupTransport
  265. flags connFlag
  266. dialDest *discover.Node
  267. wantCloseErr error
  268. wantCalls string
  269. }{
  270. {
  271. dontstart: true,
  272. tt: &setupTransport{id: id},
  273. wantCalls: "close,",
  274. wantCloseErr: errServerStopped,
  275. },
  276. {
  277. tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
  278. flags: inboundConn,
  279. wantCalls: "doEncHandshake,close,",
  280. wantCloseErr: errors.New("read error"),
  281. },
  282. {
  283. tt: &setupTransport{id: id},
  284. dialDest: &discover.Node{ID: randomID()},
  285. flags: dynDialedConn,
  286. wantCalls: "doEncHandshake,close,",
  287. wantCloseErr: DiscUnexpectedIdentity,
  288. },
  289. {
  290. tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
  291. dialDest: &discover.Node{ID: id},
  292. flags: dynDialedConn,
  293. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  294. wantCloseErr: DiscUnexpectedIdentity,
  295. },
  296. {
  297. tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
  298. dialDest: &discover.Node{ID: id},
  299. flags: dynDialedConn,
  300. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  301. wantCloseErr: errors.New("foo"),
  302. },
  303. {
  304. tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
  305. flags: inboundConn,
  306. wantCalls: "doEncHandshake,close,",
  307. wantCloseErr: DiscSelf,
  308. },
  309. {
  310. tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
  311. flags: inboundConn,
  312. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  313. wantCloseErr: DiscUselessPeer,
  314. },
  315. }
  316. for i, test := range tests {
  317. srv := &Server{
  318. PrivateKey: srvkey,
  319. MaxPeers: 10,
  320. NoDial: true,
  321. Protocols: []Protocol{discard},
  322. newTransport: func(fd net.Conn) transport { return test.tt },
  323. }
  324. if !test.dontstart {
  325. if err := srv.Start(); err != nil {
  326. t.Fatalf("couldn't start server: %v", err)
  327. }
  328. }
  329. p1, _ := net.Pipe()
  330. srv.setupConn(p1, test.flags, test.dialDest)
  331. if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
  332. t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
  333. }
  334. if test.tt.calls != test.wantCalls {
  335. t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
  336. }
  337. }
  338. }
  339. type setupTransport struct {
  340. id discover.NodeID
  341. encHandshakeErr error
  342. phs *protoHandshake
  343. protoHandshakeErr error
  344. calls string
  345. closeErr error
  346. }
  347. func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  348. c.calls += "doEncHandshake,"
  349. return c.id, c.encHandshakeErr
  350. }
  351. func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  352. c.calls += "doProtoHandshake,"
  353. if c.protoHandshakeErr != nil {
  354. return nil, c.protoHandshakeErr
  355. }
  356. return c.phs, nil
  357. }
  358. func (c *setupTransport) close(err error) {
  359. c.calls += "close,"
  360. c.closeErr = err
  361. }
  362. // setupConn shouldn't write to/read from the connection.
  363. func (c *setupTransport) WriteMsg(Msg) error {
  364. panic("WriteMsg called on setupTransport")
  365. }
  366. func (c *setupTransport) ReadMsg() (Msg, error) {
  367. panic("ReadMsg called on setupTransport")
  368. }
  369. func newkey() *ecdsa.PrivateKey {
  370. key, err := crypto.GenerateKey()
  371. if err != nil {
  372. panic("couldn't generate key: " + err.Error())
  373. }
  374. return key
  375. }
  376. func randomID() (id discover.NodeID) {
  377. for i := range id {
  378. id[i] = byte(rand.Intn(255))
  379. }
  380. return id
  381. }