peer_test.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. package p2p
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/hex"
  6. "io"
  7. "io/ioutil"
  8. "net"
  9. "reflect"
  10. "testing"
  11. "time"
  12. )
  13. var discard = Protocol{
  14. Name: "discard",
  15. Length: 1,
  16. Run: func(p *Peer, rw MsgReadWriter) error {
  17. for {
  18. msg, err := rw.ReadMsg()
  19. if err != nil {
  20. return err
  21. }
  22. if err = msg.Discard(); err != nil {
  23. return err
  24. }
  25. }
  26. },
  27. }
  28. func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
  29. conn1, conn2 := net.Pipe()
  30. peer := newPeer(conn1, protos, nil)
  31. peer.ourID = &peerId{}
  32. peer.pubkeyHook = func(*peerAddr) error { return nil }
  33. errc := make(chan error, 1)
  34. go func() {
  35. _, err := peer.loop()
  36. errc <- err
  37. }()
  38. return conn2, peer, errc
  39. }
  40. func TestPeerProtoReadMsg(t *testing.T) {
  41. defer testlog(t).detach()
  42. done := make(chan struct{})
  43. proto := Protocol{
  44. Name: "a",
  45. Length: 5,
  46. Run: func(peer *Peer, rw MsgReadWriter) error {
  47. msg, err := rw.ReadMsg()
  48. if err != nil {
  49. t.Errorf("read error: %v", err)
  50. }
  51. if msg.Code != 2 {
  52. t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
  53. }
  54. data, err := ioutil.ReadAll(msg.Payload)
  55. if err != nil {
  56. t.Errorf("payload read error: %v", err)
  57. }
  58. expdata, _ := hex.DecodeString("0183303030")
  59. if !bytes.Equal(expdata, data) {
  60. t.Errorf("incorrect msg data %x", data)
  61. }
  62. close(done)
  63. return nil
  64. },
  65. }
  66. net, peer, errc := testPeer([]Protocol{proto})
  67. defer net.Close()
  68. peer.startSubprotocols([]Cap{proto.cap()})
  69. writeMsg(net, NewMsg(18, 1, "000"))
  70. select {
  71. case <-done:
  72. case err := <-errc:
  73. t.Errorf("peer returned: %v", err)
  74. case <-time.After(2 * time.Second):
  75. t.Errorf("receive timeout")
  76. }
  77. }
  78. func TestPeerProtoReadLargeMsg(t *testing.T) {
  79. defer testlog(t).detach()
  80. msgsize := uint32(10 * 1024 * 1024)
  81. done := make(chan struct{})
  82. proto := Protocol{
  83. Name: "a",
  84. Length: 5,
  85. Run: func(peer *Peer, rw MsgReadWriter) error {
  86. msg, err := rw.ReadMsg()
  87. if err != nil {
  88. t.Errorf("read error: %v", err)
  89. }
  90. if msg.Size != msgsize+4 {
  91. t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
  92. }
  93. msg.Discard()
  94. close(done)
  95. return nil
  96. },
  97. }
  98. net, peer, errc := testPeer([]Protocol{proto})
  99. defer net.Close()
  100. peer.startSubprotocols([]Cap{proto.cap()})
  101. writeMsg(net, NewMsg(18, make([]byte, msgsize)))
  102. select {
  103. case <-done:
  104. case err := <-errc:
  105. t.Errorf("peer returned: %v", err)
  106. case <-time.After(2 * time.Second):
  107. t.Errorf("receive timeout")
  108. }
  109. }
  110. func TestPeerProtoEncodeMsg(t *testing.T) {
  111. defer testlog(t).detach()
  112. proto := Protocol{
  113. Name: "a",
  114. Length: 2,
  115. Run: func(peer *Peer, rw MsgReadWriter) error {
  116. if err := EncodeMsg(rw, 2); err == nil {
  117. t.Error("expected error for out-of-range msg code, got nil")
  118. }
  119. if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
  120. t.Errorf("write error: %v", err)
  121. }
  122. return nil
  123. },
  124. }
  125. net, peer, _ := testPeer([]Protocol{proto})
  126. defer net.Close()
  127. peer.startSubprotocols([]Cap{proto.cap()})
  128. bufr := bufio.NewReader(net)
  129. msg, err := readMsg(bufr)
  130. if err != nil {
  131. t.Errorf("read error: %v", err)
  132. }
  133. if msg.Code != 17 {
  134. t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
  135. }
  136. var data []string
  137. if err := msg.Decode(&data); err != nil {
  138. t.Errorf("payload decode error: %v", err)
  139. }
  140. if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
  141. t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
  142. }
  143. }
  144. func TestPeerWrite(t *testing.T) {
  145. defer testlog(t).detach()
  146. net, peer, peerErr := testPeer([]Protocol{discard})
  147. defer net.Close()
  148. peer.startSubprotocols([]Cap{discard.cap()})
  149. // test write errors
  150. if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
  151. t.Errorf("expected error for unknown protocol, got nil")
  152. }
  153. if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
  154. t.Errorf("expected error for out-of-range msg code, got nil")
  155. } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
  156. t.Errorf("wrong error for out-of-range msg code, got %#v", err)
  157. }
  158. // setup for reading the message on the other end
  159. read := make(chan struct{})
  160. go func() {
  161. bufr := bufio.NewReader(net)
  162. msg, err := readMsg(bufr)
  163. if err != nil {
  164. t.Errorf("read error: %v", err)
  165. } else if msg.Code != 16 {
  166. t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
  167. }
  168. msg.Discard()
  169. close(read)
  170. }()
  171. // test succcessful write
  172. if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
  173. t.Errorf("expect no error for known protocol: %v", err)
  174. }
  175. select {
  176. case <-read:
  177. case err := <-peerErr:
  178. t.Fatalf("peer stopped: %v", err)
  179. }
  180. }
  181. func TestPeerActivity(t *testing.T) {
  182. // shorten inactivityTimeout while this test is running
  183. oldT := inactivityTimeout
  184. defer func() { inactivityTimeout = oldT }()
  185. inactivityTimeout = 20 * time.Millisecond
  186. net, peer, peerErr := testPeer([]Protocol{discard})
  187. defer net.Close()
  188. peer.startSubprotocols([]Cap{discard.cap()})
  189. sub := peer.activity.Subscribe(time.Time{})
  190. defer sub.Unsubscribe()
  191. for i := 0; i < 6; i++ {
  192. writeMsg(net, NewMsg(16))
  193. select {
  194. case <-sub.Chan():
  195. case <-time.After(inactivityTimeout / 2):
  196. t.Fatal("no event within ", inactivityTimeout/2)
  197. case err := <-peerErr:
  198. t.Fatal("peer error", err)
  199. }
  200. }
  201. select {
  202. case <-time.After(inactivityTimeout * 2):
  203. case <-sub.Chan():
  204. t.Fatal("got activity event while connection was inactive")
  205. case err := <-peerErr:
  206. t.Fatal("peer error", err)
  207. }
  208. }
  209. func TestNewPeer(t *testing.T) {
  210. caps := []Cap{{"foo", 2}, {"bar", 3}}
  211. id := &peerId{}
  212. p := NewPeer(id, caps)
  213. if !reflect.DeepEqual(p.Caps(), caps) {
  214. t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
  215. }
  216. if p.Identity() != id {
  217. t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
  218. }
  219. // Should not hang.
  220. p.Disconnect(DiscAlreadyConnected)
  221. }
  222. func TestEOFSignal(t *testing.T) {
  223. rb := make([]byte, 10)
  224. // empty reader
  225. eof := make(chan struct{}, 1)
  226. sig := &eofSignal{new(bytes.Buffer), 0, eof}
  227. if n, err := sig.Read(rb); n != 0 || err != io.EOF {
  228. t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
  229. }
  230. select {
  231. case <-eof:
  232. default:
  233. t.Error("EOF chan not signaled")
  234. }
  235. // count before error
  236. eof = make(chan struct{}, 1)
  237. sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
  238. if n, err := sig.Read(rb); n != 8 || err != nil {
  239. t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
  240. }
  241. select {
  242. case <-eof:
  243. default:
  244. t.Error("EOF chan not signaled")
  245. }
  246. // error before count
  247. eof = make(chan struct{}, 1)
  248. sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
  249. if n, err := sig.Read(rb); n != 4 || err != nil {
  250. t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
  251. }
  252. if n, err := sig.Read(rb); n != 0 || err != io.EOF {
  253. t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
  254. }
  255. select {
  256. case <-eof:
  257. default:
  258. t.Error("EOF chan not signaled")
  259. }
  260. // no signal if neither occurs
  261. eof = make(chan struct{}, 1)
  262. sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
  263. if n, err := sig.Read(rb); n != 10 || err != nil {
  264. t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
  265. }
  266. select {
  267. case <-eof:
  268. t.Error("unexpected EOF signal")
  269. default:
  270. }
  271. }