rlpx_test.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. package p2p
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net"
  9. "reflect"
  10. "strings"
  11. "sync"
  12. "testing"
  13. "time"
  14. "github.com/davecgh/go-spew/spew"
  15. "github.com/ethereum/go-ethereum/crypto"
  16. "github.com/ethereum/go-ethereum/crypto/ecies"
  17. "github.com/ethereum/go-ethereum/crypto/sha3"
  18. "github.com/ethereum/go-ethereum/p2p/discover"
  19. "github.com/ethereum/go-ethereum/rlp"
  20. )
  21. func TestSharedSecret(t *testing.T) {
  22. prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
  23. pub0 := &prv0.PublicKey
  24. prv1, _ := crypto.GenerateKey()
  25. pub1 := &prv1.PublicKey
  26. ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
  27. if err != nil {
  28. return
  29. }
  30. ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
  31. if err != nil {
  32. return
  33. }
  34. t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
  35. if !bytes.Equal(ss0, ss1) {
  36. t.Errorf("dont match :(")
  37. }
  38. }
  39. func TestEncHandshake(t *testing.T) {
  40. for i := 0; i < 10; i++ {
  41. start := time.Now()
  42. if err := testEncHandshake(nil); err != nil {
  43. t.Fatalf("i=%d %v", i, err)
  44. }
  45. t.Logf("(without token) %d %v\n", i+1, time.Since(start))
  46. }
  47. for i := 0; i < 10; i++ {
  48. tok := make([]byte, shaLen)
  49. rand.Reader.Read(tok)
  50. start := time.Now()
  51. if err := testEncHandshake(tok); err != nil {
  52. t.Fatalf("i=%d %v", i, err)
  53. }
  54. t.Logf("(with token) %d %v\n", i+1, time.Since(start))
  55. }
  56. }
  57. func testEncHandshake(token []byte) error {
  58. type result struct {
  59. side string
  60. id discover.NodeID
  61. err error
  62. }
  63. var (
  64. prv0, _ = crypto.GenerateKey()
  65. prv1, _ = crypto.GenerateKey()
  66. fd0, fd1 = net.Pipe()
  67. c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
  68. output = make(chan result)
  69. )
  70. go func() {
  71. r := result{side: "initiator"}
  72. defer func() { output <- r }()
  73. dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
  74. r.id, r.err = c0.doEncHandshake(prv0, dest)
  75. if r.err != nil {
  76. return
  77. }
  78. id1 := discover.PubkeyID(&prv1.PublicKey)
  79. if r.id != id1 {
  80. r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
  81. }
  82. }()
  83. go func() {
  84. r := result{side: "receiver"}
  85. defer func() { output <- r }()
  86. r.id, r.err = c1.doEncHandshake(prv1, nil)
  87. if r.err != nil {
  88. return
  89. }
  90. id0 := discover.PubkeyID(&prv0.PublicKey)
  91. if r.id != id0 {
  92. r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
  93. }
  94. }()
  95. // wait for results from both sides
  96. r1, r2 := <-output, <-output
  97. if r1.err != nil {
  98. return fmt.Errorf("%s side error: %v", r1.side, r1.err)
  99. }
  100. if r2.err != nil {
  101. return fmt.Errorf("%s side error: %v", r2.side, r2.err)
  102. }
  103. // compare derived secrets
  104. if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
  105. return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
  106. }
  107. if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
  108. return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
  109. }
  110. if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
  111. return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
  112. }
  113. if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
  114. return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
  115. }
  116. return nil
  117. }
  118. func TestProtocolHandshake(t *testing.T) {
  119. var (
  120. prv0, _ = crypto.GenerateKey()
  121. node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
  122. hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
  123. prv1, _ = crypto.GenerateKey()
  124. node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
  125. hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
  126. fd0, fd1 = net.Pipe()
  127. wg sync.WaitGroup
  128. )
  129. wg.Add(2)
  130. go func() {
  131. defer wg.Done()
  132. rlpx := newRLPX(fd0)
  133. remid, err := rlpx.doEncHandshake(prv0, node1)
  134. if err != nil {
  135. t.Errorf("dial side enc handshake failed: %v", err)
  136. return
  137. }
  138. if remid != node1.ID {
  139. t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
  140. return
  141. }
  142. phs, err := rlpx.doProtoHandshake(hs0)
  143. if err != nil {
  144. t.Errorf("dial side proto handshake error: %v", err)
  145. return
  146. }
  147. if !reflect.DeepEqual(phs, hs1) {
  148. t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
  149. return
  150. }
  151. rlpx.close(DiscQuitting)
  152. }()
  153. go func() {
  154. defer wg.Done()
  155. rlpx := newRLPX(fd1)
  156. remid, err := rlpx.doEncHandshake(prv1, nil)
  157. if err != nil {
  158. t.Errorf("listen side enc handshake failed: %v", err)
  159. return
  160. }
  161. if remid != node0.ID {
  162. t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
  163. return
  164. }
  165. phs, err := rlpx.doProtoHandshake(hs1)
  166. if err != nil {
  167. t.Errorf("listen side proto handshake error: %v", err)
  168. return
  169. }
  170. if !reflect.DeepEqual(phs, hs0) {
  171. t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
  172. return
  173. }
  174. if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
  175. t.Errorf("error receiving disconnect: %v", err)
  176. }
  177. }()
  178. wg.Wait()
  179. }
  180. func TestProtocolHandshakeErrors(t *testing.T) {
  181. our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
  182. id := randomID()
  183. tests := []struct {
  184. code uint64
  185. msg interface{}
  186. err error
  187. }{
  188. {
  189. code: discMsg,
  190. msg: []DiscReason{DiscQuitting},
  191. err: DiscQuitting,
  192. },
  193. {
  194. code: 0x989898,
  195. msg: []byte{1},
  196. err: errors.New("expected handshake, got 989898"),
  197. },
  198. {
  199. code: handshakeMsg,
  200. msg: make([]byte, baseProtocolMaxMsgSize+2),
  201. err: errors.New("message too big"),
  202. },
  203. {
  204. code: handshakeMsg,
  205. msg: []byte{1, 2, 3},
  206. err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
  207. },
  208. {
  209. code: handshakeMsg,
  210. msg: &protoHandshake{Version: 9944, ID: id},
  211. err: DiscIncompatibleVersion,
  212. },
  213. {
  214. code: handshakeMsg,
  215. msg: &protoHandshake{Version: 3},
  216. err: DiscInvalidIdentity,
  217. },
  218. }
  219. for i, test := range tests {
  220. p1, p2 := MsgPipe()
  221. go Send(p1, test.code, test.msg)
  222. _, err := readProtocolHandshake(p2, our)
  223. if !reflect.DeepEqual(err, test.err) {
  224. t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
  225. }
  226. }
  227. }
  228. func TestRLPXFrameFake(t *testing.T) {
  229. buf := new(bytes.Buffer)
  230. hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
  231. rw := newRLPXFrameRW(buf, secrets{
  232. AES: crypto.Sha3(),
  233. MAC: crypto.Sha3(),
  234. IngressMAC: hash,
  235. EgressMAC: hash,
  236. })
  237. golden := unhex(`
  238. 00828ddae471818bb0bfa6b551d1cb42
  239. 01010101010101010101010101010101
  240. ba628a4ba590cb43f7848f41c4382885
  241. 01010101010101010101010101010101
  242. `)
  243. // Check WriteMsg. This puts a message into the buffer.
  244. if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil {
  245. t.Fatalf("WriteMsg error: %v", err)
  246. }
  247. written := buf.Bytes()
  248. if !bytes.Equal(written, golden) {
  249. t.Fatalf("output mismatch:\n got: %x\n want: %x", written, golden)
  250. }
  251. // Check ReadMsg. It reads the message encoded by WriteMsg, which
  252. // is equivalent to the golden message above.
  253. msg, err := rw.ReadMsg()
  254. if err != nil {
  255. t.Fatalf("ReadMsg error: %v", err)
  256. }
  257. if msg.Size != 5 {
  258. t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5)
  259. }
  260. if msg.Code != 8 {
  261. t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8)
  262. }
  263. payload, _ := ioutil.ReadAll(msg.Payload)
  264. wantPayload := unhex("C401020304")
  265. if !bytes.Equal(payload, wantPayload) {
  266. t.Errorf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
  267. }
  268. }
  // fakeHash provides the hash.Hash method set with a state that never changes.
// Written data is discarded, and Sum always appends the same underlying bytes.
// Frame tests use it to get a predictable MAC.
type fakeHash []byte

// Write ignores p but reports all of it as consumed, with no error.
func (h fakeHash) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}

// Reset is a no-op because there is no state to clear.
func (h fakeHash) Reset() {}

// BlockSize always reports 0.
func (h fakeHash) BlockSize() int {
	return 0
}

// Size reports the digest length, which is the length of the fixed bytes.
func (h fakeHash) Size() int {
	return len(h)
}

// Sum appends the fixed digest bytes to b and returns the extended slice.
func (h fakeHash) Sum(b []byte) []byte {
	out := append(b, h...)
	return out
}
  275. func TestRLPXFrameRW(t *testing.T) {
  276. var (
  277. aesSecret = make([]byte, 16)
  278. macSecret = make([]byte, 16)
  279. egressMACinit = make([]byte, 32)
  280. ingressMACinit = make([]byte, 32)
  281. )
  282. for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} {
  283. rand.Read(s)
  284. }
  285. conn := new(bytes.Buffer)
  286. s1 := secrets{
  287. AES: aesSecret,
  288. MAC: macSecret,
  289. EgressMAC: sha3.NewKeccak256(),
  290. IngressMAC: sha3.NewKeccak256(),
  291. }
  292. s1.EgressMAC.Write(egressMACinit)
  293. s1.IngressMAC.Write(ingressMACinit)
  294. rw1 := newRLPXFrameRW(conn, s1)
  295. s2 := secrets{
  296. AES: aesSecret,
  297. MAC: macSecret,
  298. EgressMAC: sha3.NewKeccak256(),
  299. IngressMAC: sha3.NewKeccak256(),
  300. }
  301. s2.EgressMAC.Write(ingressMACinit)
  302. s2.IngressMAC.Write(egressMACinit)
  303. rw2 := newRLPXFrameRW(conn, s2)
  304. // send some messages
  305. for i := 0; i < 10; i++ {
  306. // write message into conn buffer
  307. wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
  308. err := Send(rw1, uint64(i), wmsg)
  309. if err != nil {
  310. t.Fatalf("WriteMsg error (i=%d): %v", i, err)
  311. }
  312. // read message that rw1 just wrote
  313. msg, err := rw2.ReadMsg()
  314. if err != nil {
  315. t.Fatalf("ReadMsg error (i=%d): %v", i, err)
  316. }
  317. if msg.Code != uint64(i) {
  318. t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i)
  319. }
  320. payload, _ := ioutil.ReadAll(msg.Payload)
  321. wantPayload, _ := rlp.EncodeToBytes(wmsg)
  322. if !bytes.Equal(payload, wantPayload) {
  323. t.Fatalf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
  324. }
  325. }
  326. }