messenger_test.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package p2p
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "log"
  7. "net"
  8. "os"
  9. "reflect"
  10. "testing"
  11. "time"
  12. logpkg "github.com/ethereum/go-ethereum/logger"
  13. )
  14. func init() {
  15. logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
  16. }
  17. func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
  18. conn1, conn2 := net.Pipe()
  19. id := NewSimpleClientIdentity("test", "0", "0", "public key")
  20. server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
  21. peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
  22. return conn2, peer, peer.messenger
  23. }
  24. func performTestHandshake(r *bufio.Reader, w io.Writer) error {
  25. // read remote handshake
  26. msg, err := readMsg(r)
  27. if err != nil {
  28. return fmt.Errorf("read error: %v", err)
  29. }
  30. if msg.Code != handshakeMsg {
  31. return fmt.Errorf("first message should be handshake, got %d", msg.Code)
  32. }
  33. if err := msg.Discard(); err != nil {
  34. return err
  35. }
  36. // send empty handshake
  37. pubkey := make([]byte, 64)
  38. msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
  39. return writeMsg(w, msg)
  40. }
  41. type testProtocol struct {
  42. offset MsgCode
  43. f func(MsgReadWriter)
  44. }
  45. func (p *testProtocol) Offset() MsgCode {
  46. return p.offset
  47. }
  48. func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
  49. p.f(rw)
  50. return nil
  51. }
  52. func TestRead(t *testing.T) {
  53. done := make(chan struct{})
  54. handlers := Handlers{
  55. "a": &testProtocol{5, func(rw MsgReadWriter) {
  56. msg, err := rw.ReadMsg()
  57. if err != nil {
  58. t.Errorf("read error: %v", err)
  59. }
  60. if msg.Code != 2 {
  61. t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
  62. }
  63. data, err := msg.Data()
  64. if err != nil {
  65. t.Errorf("data decoding error: %v", err)
  66. }
  67. expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
  68. if !reflect.DeepEqual(data.Slice(), expdata) {
  69. t.Errorf("incorrect msg data %#v", data.Slice())
  70. }
  71. close(done)
  72. }},
  73. }
  74. net, peer, m := testMessenger(handlers)
  75. defer peer.Stop()
  76. bufr := bufio.NewReader(net)
  77. if err := performTestHandshake(bufr, net); err != nil {
  78. t.Fatalf("handshake failed: %v", err)
  79. }
  80. m.setRemoteProtocols([]string{"a"})
  81. writeMsg(net, NewMsg(18, 1, "000"))
  82. select {
  83. case <-done:
  84. case <-time.After(2 * time.Second):
  85. t.Errorf("receive timeout")
  86. }
  87. }
  88. func TestWriteFromProto(t *testing.T) {
  89. handlers := Handlers{
  90. "a": &testProtocol{2, func(rw MsgReadWriter) {
  91. if err := rw.WriteMsg(NewMsg(2)); err == nil {
  92. t.Error("expected error for out-of-range msg code, got nil")
  93. }
  94. if err := rw.WriteMsg(NewMsg(1)); err != nil {
  95. t.Errorf("write error: %v", err)
  96. }
  97. }},
  98. }
  99. net, peer, mess := testMessenger(handlers)
  100. defer peer.Stop()
  101. bufr := bufio.NewReader(net)
  102. if err := performTestHandshake(bufr, net); err != nil {
  103. t.Fatalf("handshake failed: %v", err)
  104. }
  105. mess.setRemoteProtocols([]string{"a"})
  106. msg, err := readMsg(bufr)
  107. if err != nil {
  108. t.Errorf("read error: %v")
  109. }
  110. if msg.Code != 17 {
  111. t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
  112. }
  113. }
  114. var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
  115. for {
  116. msg, err := rw.ReadMsg()
  117. if err != nil {
  118. return
  119. }
  120. if err = msg.Discard(); err != nil {
  121. return
  122. }
  123. }
  124. }}
  125. func TestMessengerWriteProtoMsg(t *testing.T) {
  126. handlers := Handlers{"a": discardProto}
  127. net, peer, mess := testMessenger(handlers)
  128. defer peer.Stop()
  129. bufr := bufio.NewReader(net)
  130. if err := performTestHandshake(bufr, net); err != nil {
  131. t.Fatalf("handshake failed: %v", err)
  132. }
  133. mess.setRemoteProtocols([]string{"a"})
  134. // test write errors
  135. if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
  136. t.Errorf("expected error for unknown protocol, got nil")
  137. }
  138. if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
  139. t.Errorf("expected error for out-of-range msg code, got nil")
  140. } else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
  141. t.Errorf("wrong error for out-of-range msg code, got %#v")
  142. }
  143. // test succcessful write
  144. read, readerr := make(chan Msg), make(chan error)
  145. go func() {
  146. if msg, err := readMsg(bufr); err != nil {
  147. readerr <- err
  148. } else {
  149. read <- msg
  150. }
  151. }()
  152. if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
  153. t.Errorf("expect no error for known protocol: %v", err)
  154. }
  155. select {
  156. case msg := <-read:
  157. if msg.Code != 16 {
  158. t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
  159. }
  160. msg.Discard()
  161. case err := <-readerr:
  162. t.Errorf("read error: %v", err)
  163. }
  164. }
  165. func TestPulse(t *testing.T) {
  166. net, peer, _ := testMessenger(nil)
  167. defer peer.Stop()
  168. bufr := bufio.NewReader(net)
  169. if err := performTestHandshake(bufr, net); err != nil {
  170. t.Fatalf("handshake failed: %v", err)
  171. }
  172. before := time.Now()
  173. msg, err := readMsg(bufr)
  174. if err != nil {
  175. t.Fatalf("read error: %v", err)
  176. }
  177. after := time.Now()
  178. if msg.Code != pingMsg {
  179. t.Errorf("expected ping message, got %d", msg.Code)
  180. }
  181. if d := after.Sub(before); d < pingTimeout {
  182. t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
  183. }
  184. }