// handshake_test.go (4.2 KB)
  1. package p2p
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "fmt"
  6. "net"
  7. "reflect"
  8. "testing"
  9. "time"
  10. "github.com/ethereum/go-ethereum/crypto"
  11. "github.com/ethereum/go-ethereum/crypto/ecies"
  12. "github.com/ethereum/go-ethereum/p2p/discover"
  13. )
  14. func TestSharedSecret(t *testing.T) {
  15. prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
  16. pub0 := &prv0.PublicKey
  17. prv1, _ := crypto.GenerateKey()
  18. pub1 := &prv1.PublicKey
  19. ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
  20. if err != nil {
  21. return
  22. }
  23. ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
  24. if err != nil {
  25. return
  26. }
  27. t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
  28. if !bytes.Equal(ss0, ss1) {
  29. t.Errorf("dont match :(")
  30. }
  31. }
  32. func TestEncHandshake(t *testing.T) {
  33. for i := 0; i < 20; i++ {
  34. start := time.Now()
  35. if err := testEncHandshake(nil); err != nil {
  36. t.Fatalf("i=%d %v", i, err)
  37. }
  38. t.Logf("(without token) %d %v\n", i+1, time.Since(start))
  39. }
  40. for i := 0; i < 20; i++ {
  41. tok := make([]byte, shaLen)
  42. rand.Reader.Read(tok)
  43. start := time.Now()
  44. if err := testEncHandshake(tok); err != nil {
  45. t.Fatalf("i=%d %v", i, err)
  46. }
  47. t.Logf("(with token) %d %v\n", i+1, time.Since(start))
  48. }
  49. }
  50. func testEncHandshake(token []byte) error {
  51. type result struct {
  52. side string
  53. s secrets
  54. err error
  55. }
  56. var (
  57. prv0, _ = crypto.GenerateKey()
  58. prv1, _ = crypto.GenerateKey()
  59. rw0, rw1 = net.Pipe()
  60. output = make(chan result)
  61. )
  62. go func() {
  63. r := result{side: "initiator"}
  64. defer func() { output <- r }()
  65. pub1s := discover.PubkeyID(&prv1.PublicKey)
  66. r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
  67. if r.err != nil {
  68. return
  69. }
  70. id1 := discover.PubkeyID(&prv1.PublicKey)
  71. if r.s.RemoteID != id1 {
  72. r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
  73. }
  74. }()
  75. go func() {
  76. r := result{side: "receiver"}
  77. defer func() { output <- r }()
  78. r.s, r.err = receiverEncHandshake(rw1, prv1, token)
  79. if r.err != nil {
  80. return
  81. }
  82. id0 := discover.PubkeyID(&prv0.PublicKey)
  83. if r.s.RemoteID != id0 {
  84. r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
  85. }
  86. }()
  87. // wait for results from both sides
  88. r1, r2 := <-output, <-output
  89. if r1.err != nil {
  90. return fmt.Errorf("%s side error: %v", r1.side, r1.err)
  91. }
  92. if r2.err != nil {
  93. return fmt.Errorf("%s side error: %v", r2.side, r2.err)
  94. }
  95. // don't compare remote node IDs
  96. r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
  97. // flip MACs on one of them so they compare equal
  98. r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
  99. if !reflect.DeepEqual(r1.s, r2.s) {
  100. return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
  101. }
  102. return nil
  103. }
  104. func TestSetupConn(t *testing.T) {
  105. prv0, _ := crypto.GenerateKey()
  106. prv1, _ := crypto.GenerateKey()
  107. node0 := &discover.Node{
  108. ID: discover.PubkeyID(&prv0.PublicKey),
  109. IP: net.IP{1, 2, 3, 4},
  110. TCP: 33,
  111. }
  112. node1 := &discover.Node{
  113. ID: discover.PubkeyID(&prv1.PublicKey),
  114. IP: net.IP{5, 6, 7, 8},
  115. TCP: 44,
  116. }
  117. hs0 := &protoHandshake{
  118. Version: baseProtocolVersion,
  119. ID: node0.ID,
  120. Caps: []Cap{{"a", 0}, {"b", 2}},
  121. }
  122. hs1 := &protoHandshake{
  123. Version: baseProtocolVersion,
  124. ID: node1.ID,
  125. Caps: []Cap{{"c", 1}, {"d", 3}},
  126. }
  127. fd0, fd1 := net.Pipe()
  128. done := make(chan struct{})
  129. keepalways := func(discover.NodeID) bool { return true }
  130. go func() {
  131. defer close(done)
  132. conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways)
  133. if err != nil {
  134. t.Errorf("outbound side error: %v", err)
  135. return
  136. }
  137. if conn0.ID != node1.ID {
  138. t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
  139. }
  140. if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
  141. t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
  142. }
  143. }()
  144. conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways)
  145. if err != nil {
  146. t.Fatalf("inbound side error: %v", err)
  147. }
  148. if conn1.ID != node0.ID {
  149. t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
  150. }
  151. if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
  152. t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
  153. }
  154. <-done
  155. }