server_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. // Copyright 2014 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package p2p
  17. import (
  18. "crypto/ecdsa"
  19. "errors"
  20. "math/rand"
  21. "net"
  22. "reflect"
  23. "testing"
  24. "time"
  25. "github.com/ethereum/go-ethereum/crypto"
  26. "github.com/ethereum/go-ethereum/crypto/sha3"
  27. "github.com/ethereum/go-ethereum/p2p/discover"
  28. )
  // init is a debugging hook for this test package. It is empty by default;
// uncommenting the calls below (which also requires importing glog) turns on
// verbose logging to stderr.
func init() {
	// glog.SetV(6)
	// glog.SetToStderr(true)
}
  33. type testTransport struct {
  34. id discover.NodeID
  35. *rlpx
  36. closeErr error
  37. }
  38. func newTestTransport(id discover.NodeID, fd net.Conn) transport {
  39. wrapped := newRLPX(fd).(*rlpx)
  40. wrapped.rw = newRLPXFrameRW(fd, secrets{
  41. MAC: zero16,
  42. AES: zero16,
  43. IngressMAC: sha3.NewKeccak256(),
  44. EgressMAC: sha3.NewKeccak256(),
  45. })
  46. return &testTransport{id: id, rlpx: wrapped}
  47. }
  48. func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  49. return c.id, nil
  50. }
  51. func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  52. return &protoHandshake{ID: c.id, Name: "test"}, nil
  53. }
  54. func (c *testTransport) close(err error) {
  55. c.rlpx.fd.Close()
  56. c.closeErr = err
  57. }
  58. func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
  59. server := &Server{
  60. Name: "test",
  61. MaxPeers: 10,
  62. ListenAddr: "127.0.0.1:0",
  63. PrivateKey: newkey(),
  64. newPeerHook: pf,
  65. newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
  66. }
  67. if err := server.Start(); err != nil {
  68. t.Fatalf("Could not start server: %v", err)
  69. }
  70. return server
  71. }
  72. func TestServerListen(t *testing.T) {
  73. // start the test server
  74. connected := make(chan *Peer)
  75. remid := randomID()
  76. srv := startTestServer(t, remid, func(p *Peer) {
  77. if p.ID() != remid {
  78. t.Error("peer func called with wrong node id")
  79. }
  80. if p == nil {
  81. t.Error("peer func called with nil conn")
  82. }
  83. connected <- p
  84. })
  85. defer close(connected)
  86. defer srv.Stop()
  87. // dial the test server
  88. conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
  89. if err != nil {
  90. t.Fatalf("could not dial: %v", err)
  91. }
  92. defer conn.Close()
  93. select {
  94. case peer := <-connected:
  95. if peer.LocalAddr().String() != conn.RemoteAddr().String() {
  96. t.Errorf("peer started with wrong conn: got %v, want %v",
  97. peer.LocalAddr(), conn.RemoteAddr())
  98. }
  99. peers := srv.Peers()
  100. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  101. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  102. }
  103. case <-time.After(1 * time.Second):
  104. t.Error("server did not accept within one second")
  105. }
  106. }
  107. func TestServerDial(t *testing.T) {
  108. // run a one-shot TCP server to handle the connection.
  109. listener, err := net.Listen("tcp", "127.0.0.1:0")
  110. if err != nil {
  111. t.Fatalf("could not setup listener: %v", err)
  112. }
  113. defer listener.Close()
  114. accepted := make(chan net.Conn)
  115. go func() {
  116. conn, err := listener.Accept()
  117. if err != nil {
  118. t.Error("accept error:", err)
  119. return
  120. }
  121. accepted <- conn
  122. }()
  123. // start the server
  124. connected := make(chan *Peer)
  125. remid := randomID()
  126. srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
  127. defer close(connected)
  128. defer srv.Stop()
  129. // tell the server to connect
  130. tcpAddr := listener.Addr().(*net.TCPAddr)
  131. srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
  132. select {
  133. case conn := <-accepted:
  134. defer conn.Close()
  135. select {
  136. case peer := <-connected:
  137. if peer.ID() != remid {
  138. t.Errorf("peer has wrong id")
  139. }
  140. if peer.Name() != "test" {
  141. t.Errorf("peer has wrong name")
  142. }
  143. if peer.RemoteAddr().String() != conn.LocalAddr().String() {
  144. t.Errorf("peer started with wrong conn: got %v, want %v",
  145. peer.RemoteAddr(), conn.LocalAddr())
  146. }
  147. peers := srv.Peers()
  148. if !reflect.DeepEqual(peers, []*Peer{peer}) {
  149. t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
  150. }
  151. case <-time.After(1 * time.Second):
  152. t.Error("server did not launch peer within one second")
  153. }
  154. case <-time.After(1 * time.Second):
  155. t.Error("server did not connect within one second")
  156. }
  157. }
  158. // This test checks that tasks generated by dialstate are
  159. // actually executed and taskdone is called for them.
  160. func TestServerTaskScheduling(t *testing.T) {
  161. var (
  162. done = make(chan *testTask)
  163. quit, returned = make(chan struct{}), make(chan struct{})
  164. tc = 0
  165. tg = taskgen{
  166. newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
  167. tc++
  168. return []task{&testTask{index: tc - 1}}
  169. },
  170. doneFunc: func(t task) {
  171. select {
  172. case done <- t.(*testTask):
  173. case <-quit:
  174. }
  175. },
  176. }
  177. )
  178. // The Server in this test isn't actually running
  179. // because we're only interested in what run does.
  180. srv := &Server{
  181. MaxPeers: 10,
  182. quit: make(chan struct{}),
  183. ntab: fakeTable{},
  184. running: true,
  185. }
  186. srv.loopWG.Add(1)
  187. go func() {
  188. srv.run(tg)
  189. close(returned)
  190. }()
  191. var gotdone []*testTask
  192. for i := 0; i < 100; i++ {
  193. gotdone = append(gotdone, <-done)
  194. }
  195. for i, task := range gotdone {
  196. if task.index != i {
  197. t.Errorf("task %d has wrong index, got %d", i, task.index)
  198. break
  199. }
  200. if !task.called {
  201. t.Errorf("task %d was not called", i)
  202. break
  203. }
  204. }
  205. close(quit)
  206. srv.Stop()
  207. select {
  208. case <-returned:
  209. case <-time.After(500 * time.Millisecond):
  210. t.Error("Server.run did not return within 500ms")
  211. }
  212. }
  213. type taskgen struct {
  214. newFunc func(running int, peers map[discover.NodeID]*Peer) []task
  215. doneFunc func(task)
  216. }
  217. func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
  218. return tg.newFunc(running, peers)
  219. }
  220. func (tg taskgen) taskDone(t task, now time.Time) {
  221. tg.doneFunc(t)
  222. }
  223. func (tg taskgen) addStatic(*discover.Node) {
  224. }
  225. type testTask struct {
  226. index int
  227. called bool
  228. }
  229. func (t *testTask) Do(srv *Server) {
  230. t.called = true
  231. }
  232. // This test checks that connections are disconnected
  233. // just after the encryption handshake when the server is
  234. // at capacity. Trusted connections should still be accepted.
  235. func TestServerAtCap(t *testing.T) {
  236. trustedID := randomID()
  237. srv := &Server{
  238. PrivateKey: newkey(),
  239. MaxPeers: 10,
  240. NoDial: true,
  241. TrustedNodes: []*discover.Node{{ID: trustedID}},
  242. }
  243. if err := srv.Start(); err != nil {
  244. t.Fatalf("could not start: %v", err)
  245. }
  246. defer srv.Stop()
  247. newconn := func(id discover.NodeID) *conn {
  248. fd, _ := net.Pipe()
  249. tx := newTestTransport(id, fd)
  250. return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
  251. }
  252. // Inject a few connections to fill up the peer set.
  253. for i := 0; i < 10; i++ {
  254. c := newconn(randomID())
  255. if err := srv.checkpoint(c, srv.addpeer); err != nil {
  256. t.Fatalf("could not add conn %d: %v", i, err)
  257. }
  258. }
  259. // Try inserting a non-trusted connection.
  260. c := newconn(randomID())
  261. if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
  262. t.Error("wrong error for insert:", err)
  263. }
  264. // Try inserting a trusted connection.
  265. c = newconn(trustedID)
  266. if err := srv.checkpoint(c, srv.posthandshake); err != nil {
  267. t.Error("unexpected error for trusted conn @posthandshake:", err)
  268. }
  269. if !c.is(trustedConn) {
  270. t.Error("Server did not set trusted flag")
  271. }
  272. }
  273. func TestServerSetupConn(t *testing.T) {
  274. id := randomID()
  275. srvkey := newkey()
  276. srvid := discover.PubkeyID(&srvkey.PublicKey)
  277. tests := []struct {
  278. dontstart bool
  279. tt *setupTransport
  280. flags connFlag
  281. dialDest *discover.Node
  282. wantCloseErr error
  283. wantCalls string
  284. }{
  285. {
  286. dontstart: true,
  287. tt: &setupTransport{id: id},
  288. wantCalls: "close,",
  289. wantCloseErr: errServerStopped,
  290. },
  291. {
  292. tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
  293. flags: inboundConn,
  294. wantCalls: "doEncHandshake,close,",
  295. wantCloseErr: errors.New("read error"),
  296. },
  297. {
  298. tt: &setupTransport{id: id},
  299. dialDest: &discover.Node{ID: randomID()},
  300. flags: dynDialedConn,
  301. wantCalls: "doEncHandshake,close,",
  302. wantCloseErr: DiscUnexpectedIdentity,
  303. },
  304. {
  305. tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
  306. dialDest: &discover.Node{ID: id},
  307. flags: dynDialedConn,
  308. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  309. wantCloseErr: DiscUnexpectedIdentity,
  310. },
  311. {
  312. tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
  313. dialDest: &discover.Node{ID: id},
  314. flags: dynDialedConn,
  315. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  316. wantCloseErr: errors.New("foo"),
  317. },
  318. {
  319. tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
  320. flags: inboundConn,
  321. wantCalls: "doEncHandshake,close,",
  322. wantCloseErr: DiscSelf,
  323. },
  324. {
  325. tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
  326. flags: inboundConn,
  327. wantCalls: "doEncHandshake,doProtoHandshake,close,",
  328. wantCloseErr: DiscUselessPeer,
  329. },
  330. }
  331. for i, test := range tests {
  332. srv := &Server{
  333. PrivateKey: srvkey,
  334. MaxPeers: 10,
  335. NoDial: true,
  336. Protocols: []Protocol{discard},
  337. newTransport: func(fd net.Conn) transport { return test.tt },
  338. }
  339. if !test.dontstart {
  340. if err := srv.Start(); err != nil {
  341. t.Fatalf("couldn't start server: %v", err)
  342. }
  343. }
  344. p1, _ := net.Pipe()
  345. srv.setupConn(p1, test.flags, test.dialDest)
  346. if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
  347. t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
  348. }
  349. if test.tt.calls != test.wantCalls {
  350. t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
  351. }
  352. }
  353. }
  354. type setupTransport struct {
  355. id discover.NodeID
  356. encHandshakeErr error
  357. phs *protoHandshake
  358. protoHandshakeErr error
  359. calls string
  360. closeErr error
  361. }
  362. func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
  363. c.calls += "doEncHandshake,"
  364. return c.id, c.encHandshakeErr
  365. }
  366. func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
  367. c.calls += "doProtoHandshake,"
  368. if c.protoHandshakeErr != nil {
  369. return nil, c.protoHandshakeErr
  370. }
  371. return c.phs, nil
  372. }
  373. func (c *setupTransport) close(err error) {
  374. c.calls += "close,"
  375. c.closeErr = err
  376. }
  377. // setupConn shouldn't write to/read from the connection.
  378. func (c *setupTransport) WriteMsg(Msg) error {
  379. panic("WriteMsg called on setupTransport")
  380. }
  381. func (c *setupTransport) ReadMsg() (Msg, error) {
  382. panic("ReadMsg called on setupTransport")
  383. }
  384. func newkey() *ecdsa.PrivateKey {
  385. key, err := crypto.GenerateKey()
  386. if err != nil {
  387. panic("couldn't generate key: " + err.Error())
  388. }
  389. return key
  390. }
  391. func randomID() (id discover.NodeID) {
  392. for i := range id {
  393. id[i] = byte(rand.Intn(255))
  394. }
  395. return id
  396. }