handshake_test.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package p2p
  2. import (
  3. "bytes"
  4. "net"
  5. "reflect"
  6. "testing"
  7. "github.com/ethereum/go-ethereum/crypto"
  8. "github.com/ethereum/go-ethereum/crypto/ecies"
  9. "github.com/ethereum/go-ethereum/p2p/discover"
  10. )
  11. func TestPublicKeyEncoding(t *testing.T) {
  12. prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
  13. pub0 := &prv0.PublicKey
  14. pub0s := crypto.FromECDSAPub(pub0)
  15. pub1, err := importPublicKey(pub0s)
  16. if err != nil {
  17. t.Errorf("%v", err)
  18. }
  19. eciesPub1 := ecies.ImportECDSAPublic(pub1)
  20. if eciesPub1 == nil {
  21. t.Errorf("invalid ecdsa public key")
  22. }
  23. pub1s, err := exportPublicKey(pub1)
  24. if err != nil {
  25. t.Errorf("%v", err)
  26. }
  27. if len(pub1s) != 64 {
  28. t.Errorf("wrong length expect 64, got", len(pub1s))
  29. }
  30. pub2, err := importPublicKey(pub1s)
  31. if err != nil {
  32. t.Errorf("%v", err)
  33. }
  34. pub2s, err := exportPublicKey(pub2)
  35. if err != nil {
  36. t.Errorf("%v", err)
  37. }
  38. if !bytes.Equal(pub1s, pub2s) {
  39. t.Errorf("exports dont match")
  40. }
  41. pub2sEC := crypto.FromECDSAPub(pub2)
  42. if !bytes.Equal(pub0s, pub2sEC) {
  43. t.Errorf("exports dont match")
  44. }
  45. }
  46. func TestSharedSecret(t *testing.T) {
  47. prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
  48. pub0 := &prv0.PublicKey
  49. prv1, _ := crypto.GenerateKey()
  50. pub1 := &prv1.PublicKey
  51. ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
  52. if err != nil {
  53. return
  54. }
  55. ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
  56. if err != nil {
  57. return
  58. }
  59. t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
  60. if !bytes.Equal(ss0, ss1) {
  61. t.Errorf("dont match :(")
  62. }
  63. }
  64. func TestEncHandshake(t *testing.T) {
  65. defer testlog(t).detach()
  66. prv0, _ := crypto.GenerateKey()
  67. prv1, _ := crypto.GenerateKey()
  68. rw0, rw1 := net.Pipe()
  69. secrets := make(chan secrets)
  70. go func() {
  71. pub1s, _ := exportPublicKey(&prv1.PublicKey)
  72. s, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
  73. if err != nil {
  74. t.Errorf("outbound side error: %v", err)
  75. }
  76. id1 := discover.PubkeyID(&prv1.PublicKey)
  77. if s.RemoteID != id1 {
  78. t.Errorf("outbound side remote ID mismatch")
  79. }
  80. secrets <- s
  81. }()
  82. go func() {
  83. s, err := inboundEncHandshake(rw1, prv1, nil)
  84. if err != nil {
  85. t.Errorf("inbound side error: %v", err)
  86. }
  87. id0 := discover.PubkeyID(&prv0.PublicKey)
  88. if s.RemoteID != id0 {
  89. t.Errorf("inbound side remote ID mismatch")
  90. }
  91. secrets <- s
  92. }()
  93. // get computed secrets from both sides
  94. t1, t2 := <-secrets, <-secrets
  95. // don't compare remote node IDs
  96. t1.RemoteID, t2.RemoteID = discover.NodeID{}, discover.NodeID{}
  97. // flip MACs on one of them so they compare equal
  98. t1.EgressMAC, t1.IngressMAC = t1.IngressMAC, t1.EgressMAC
  99. if !reflect.DeepEqual(t1, t2) {
  100. t.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", t1, t2)
  101. }
  102. }
  103. func TestSetupConn(t *testing.T) {
  104. prv0, _ := crypto.GenerateKey()
  105. prv1, _ := crypto.GenerateKey()
  106. node0 := &discover.Node{
  107. ID: discover.PubkeyID(&prv0.PublicKey),
  108. IP: net.IP{1, 2, 3, 4},
  109. TCPPort: 33,
  110. }
  111. node1 := &discover.Node{
  112. ID: discover.PubkeyID(&prv1.PublicKey),
  113. IP: net.IP{5, 6, 7, 8},
  114. TCPPort: 44,
  115. }
  116. hs0 := &protoHandshake{
  117. Version: baseProtocolVersion,
  118. ID: node0.ID,
  119. Caps: []Cap{{"a", 0}, {"b", 2}},
  120. }
  121. hs1 := &protoHandshake{
  122. Version: baseProtocolVersion,
  123. ID: node1.ID,
  124. Caps: []Cap{{"c", 1}, {"d", 3}},
  125. }
  126. fd0, fd1 := net.Pipe()
  127. done := make(chan struct{})
  128. go func() {
  129. defer close(done)
  130. conn0, err := setupConn(fd0, prv0, hs0, node1)
  131. if err != nil {
  132. t.Errorf("outbound side error: %v", err)
  133. return
  134. }
  135. if conn0.ID != node1.ID {
  136. t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
  137. }
  138. if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
  139. t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
  140. }
  141. }()
  142. conn1, err := setupConn(fd1, prv1, hs1, nil)
  143. if err != nil {
  144. t.Fatalf("inbound side error: %v", err)
  145. }
  146. if conn1.ID != node0.ID {
  147. t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
  148. }
  149. if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
  150. t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
  151. }
  152. <-done
  153. }