suite.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. package ethtest
  2. import (
  3. "crypto/ecdsa"
  4. "fmt"
  5. "net"
  6. "reflect"
  7. "time"
  8. "github.com/ethereum/go-ethereum/core/types"
  9. "github.com/ethereum/go-ethereum/crypto"
  10. "github.com/ethereum/go-ethereum/internal/utesting"
  11. "github.com/ethereum/go-ethereum/p2p"
  12. "github.com/ethereum/go-ethereum/p2p/enode"
  13. "github.com/ethereum/go-ethereum/p2p/rlpx"
  14. "github.com/ethereum/go-ethereum/rlp"
  15. "github.com/stretchr/testify/assert"
  16. )
  17. // Suite represents a structure used to test the eth
  18. // protocol of a node(s).
  19. type Suite struct {
  20. Dest *enode.Node
  21. chain *Chain
  22. fullChain *Chain
  23. }
  24. type Conn struct {
  25. *rlpx.Conn
  26. ourKey *ecdsa.PrivateKey
  27. }
  28. func (c *Conn) Read() Message {
  29. code, rawData, _, err := c.Conn.Read()
  30. if err != nil {
  31. return &Error{fmt.Errorf("could not read from connection: %v", err)}
  32. }
  33. var msg Message
  34. switch int(code) {
  35. case (Hello{}).Code():
  36. msg = new(Hello)
  37. case (Disconnect{}).Code():
  38. msg = new(Disconnect)
  39. case (Status{}).Code():
  40. msg = new(Status)
  41. case (GetBlockHeaders{}).Code():
  42. msg = new(GetBlockHeaders)
  43. case (BlockHeaders{}).Code():
  44. msg = new(BlockHeaders)
  45. case (GetBlockBodies{}).Code():
  46. msg = new(GetBlockBodies)
  47. case (BlockBodies{}).Code():
  48. msg = new(BlockBodies)
  49. case (NewBlock{}).Code():
  50. msg = new(NewBlock)
  51. case (NewBlockHashes{}).Code():
  52. msg = new(NewBlockHashes)
  53. default:
  54. return &Error{fmt.Errorf("invalid message code: %d", code)}
  55. }
  56. if err := rlp.DecodeBytes(rawData, msg); err != nil {
  57. return &Error{fmt.Errorf("could not rlp decode message: %v", err)}
  58. }
  59. return msg
  60. }
  61. func (c *Conn) Write(msg Message) error {
  62. payload, err := rlp.EncodeToBytes(msg)
  63. if err != nil {
  64. return err
  65. }
  66. _, err = c.Conn.Write(uint64(msg.Code()), payload)
  67. return err
  68. }
  69. // handshake checks to make sure a `HELLO` is received.
  70. func (c *Conn) handshake(t *utesting.T) Message {
  71. // write protoHandshake to client
  72. pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
  73. ourHandshake := &Hello{
  74. Version: 3,
  75. Caps: []p2p.Cap{{Name: "eth", Version: 64}, {Name: "eth", Version: 65}},
  76. ID: pub0,
  77. }
  78. if err := c.Write(ourHandshake); err != nil {
  79. t.Fatalf("could not write to connection: %v", err)
  80. }
  81. // read protoHandshake from client
  82. switch msg := c.Read().(type) {
  83. case *Hello:
  84. return msg
  85. default:
  86. t.Fatalf("bad handshake: %v", msg)
  87. return nil
  88. }
  89. }
  90. // statusExchange performs a `Status` message exchange with the given
  91. // node.
  92. func (c *Conn) statusExchange(t *utesting.T, chain *Chain) Message {
  93. // read status message from client
  94. var message Message
  95. switch msg := c.Read().(type) {
  96. case *Status:
  97. if msg.Head != chain.blocks[chain.Len()-1].Hash() {
  98. t.Fatalf("wrong head in status: %v", msg.Head)
  99. }
  100. if msg.TD.Cmp(chain.TD(chain.Len())) != 0 {
  101. t.Fatalf("wrong TD in status: %v", msg.TD)
  102. }
  103. if !reflect.DeepEqual(msg.ForkID, chain.ForkID()) {
  104. t.Fatalf("wrong fork ID in status: %v", msg.ForkID)
  105. }
  106. message = msg
  107. default:
  108. t.Fatalf("bad status message: %v", msg)
  109. }
  110. // write status message to client
  111. status := Status{
  112. ProtocolVersion: 65,
  113. NetworkID: 1,
  114. TD: chain.TD(chain.Len()),
  115. Head: chain.blocks[chain.Len()-1].Hash(),
  116. Genesis: chain.blocks[0].Hash(),
  117. ForkID: chain.ForkID(),
  118. }
  119. if err := c.Write(status); err != nil {
  120. t.Fatalf("could not write to connection: %v", err)
  121. }
  122. return message
  123. }
  124. // waitForBlock waits for confirmation from the client that it has
  125. // imported the given block.
  126. func (c *Conn) waitForBlock(block *types.Block) error {
  127. for {
  128. req := &GetBlockHeaders{Origin: hashOrNumber{Hash: block.Hash()}, Amount: 1}
  129. if err := c.Write(req); err != nil {
  130. return err
  131. }
  132. switch msg := c.Read().(type) {
  133. case *BlockHeaders:
  134. if len(*msg) > 0 {
  135. return nil
  136. }
  137. time.Sleep(100 * time.Millisecond)
  138. default:
  139. return fmt.Errorf("invalid message: %v", msg)
  140. }
  141. }
  142. }
  143. // NewSuite creates and returns a new eth-test suite that can
  144. // be used to test the given node against the given blockchain
  145. // data.
  146. func NewSuite(dest *enode.Node, chainfile string, genesisfile string) *Suite {
  147. chain, err := loadChain(chainfile, genesisfile)
  148. if err != nil {
  149. panic(err)
  150. }
  151. return &Suite{
  152. Dest: dest,
  153. chain: chain.Shorten(1000),
  154. fullChain: chain,
  155. }
  156. }
  157. func (s *Suite) AllTests() []utesting.Test {
  158. return []utesting.Test{
  159. {Name: "Status", Fn: s.TestStatus},
  160. {Name: "GetBlockHeaders", Fn: s.TestGetBlockHeaders},
  161. {Name: "Broadcast", Fn: s.TestBroadcast},
  162. {Name: "GetBlockBodies", Fn: s.TestGetBlockBodies},
  163. }
  164. }
  165. // TestStatus attempts to connect to the given node and exchange
  166. // a status message with it, and then check to make sure
  167. // the chain head is correct.
  168. func (s *Suite) TestStatus(t *utesting.T) {
  169. conn, err := s.dial()
  170. if err != nil {
  171. t.Fatalf("could not dial: %v", err)
  172. }
  173. // get protoHandshake
  174. conn.handshake(t)
  175. // get status
  176. switch msg := conn.statusExchange(t, s.chain).(type) {
  177. case *Status:
  178. t.Logf("%+v\n", msg)
  179. default:
  180. t.Fatalf("error: %v", msg)
  181. }
  182. }
  183. // TestGetBlockHeaders tests whether the given node can respond to
  184. // a `GetBlockHeaders` request and that the response is accurate.
  185. func (s *Suite) TestGetBlockHeaders(t *utesting.T) {
  186. conn, err := s.dial()
  187. if err != nil {
  188. t.Fatalf("could not dial: %v", err)
  189. }
  190. conn.handshake(t)
  191. conn.statusExchange(t, s.chain)
  192. // get block headers
  193. req := &GetBlockHeaders{
  194. Origin: hashOrNumber{
  195. Hash: s.chain.blocks[1].Hash(),
  196. },
  197. Amount: 2,
  198. Skip: 1,
  199. Reverse: false,
  200. }
  201. if err := conn.Write(req); err != nil {
  202. t.Fatalf("could not write to connection: %v", err)
  203. }
  204. switch msg := conn.Read().(type) {
  205. case *BlockHeaders:
  206. headers := msg
  207. for _, header := range *headers {
  208. num := header.Number.Uint64()
  209. assert.Equal(t, s.chain.blocks[int(num)].Header(), header)
  210. t.Logf("\nHEADER FOR BLOCK NUMBER %d: %+v\n", header.Number, header)
  211. }
  212. default:
  213. t.Fatalf("error: %v", msg)
  214. }
  215. }
  216. // TestGetBlockBodies tests whether the given node can respond to
  217. // a `GetBlockBodies` request and that the response is accurate.
  218. func (s *Suite) TestGetBlockBodies(t *utesting.T) {
  219. conn, err := s.dial()
  220. if err != nil {
  221. t.Fatalf("could not dial: %v", err)
  222. }
  223. conn.handshake(t)
  224. conn.statusExchange(t, s.chain)
  225. // create block bodies request
  226. req := &GetBlockBodies{s.chain.blocks[54].Hash(), s.chain.blocks[75].Hash()}
  227. if err := conn.Write(req); err != nil {
  228. t.Fatalf("could not write to connection: %v", err)
  229. }
  230. switch msg := conn.Read().(type) {
  231. case *BlockBodies:
  232. bodies := msg
  233. for _, body := range *bodies {
  234. t.Logf("\nBODY: %+v\n", body)
  235. }
  236. default:
  237. t.Fatalf("error: %v", msg)
  238. }
  239. }
  240. // TestBroadcast tests whether a block announcement is correctly
  241. // propagated to the given node's peer(s).
  242. func (s *Suite) TestBroadcast(t *utesting.T) {
  243. // create conn to send block announcement
  244. sendConn, err := s.dial()
  245. if err != nil {
  246. t.Fatalf("could not dial: %v", err)
  247. }
  248. // create conn to receive block announcement
  249. receiveConn, err := s.dial()
  250. if err != nil {
  251. t.Fatalf("could not dial: %v", err)
  252. }
  253. sendConn.handshake(t)
  254. receiveConn.handshake(t)
  255. sendConn.statusExchange(t, s.chain)
  256. receiveConn.statusExchange(t, s.chain)
  257. // sendConn sends the block announcement
  258. blockAnnouncement := &NewBlock{
  259. Block: s.fullChain.blocks[1000],
  260. TD: s.fullChain.TD(1001),
  261. }
  262. if err := sendConn.Write(blockAnnouncement); err != nil {
  263. t.Fatalf("could not write to connection: %v", err)
  264. }
  265. switch msg := receiveConn.Read().(type) {
  266. case *NewBlock:
  267. assert.Equal(t, blockAnnouncement.Block.Header(), msg.Block.Header(),
  268. "wrong block header in announcement")
  269. assert.Equal(t, blockAnnouncement.TD, msg.TD,
  270. "wrong TD in announcement")
  271. case *NewBlockHashes:
  272. hashes := *msg
  273. assert.Equal(t, blockAnnouncement.Block.Hash(), hashes[0].Hash,
  274. "wrong block hash in announcement")
  275. default:
  276. t.Fatal(msg)
  277. }
  278. // update test suite chain
  279. s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[1000])
  280. // wait for client to update its chain
  281. if err := receiveConn.waitForBlock(s.chain.Head()); err != nil {
  282. t.Fatal(err)
  283. }
  284. }
  285. // dial attempts to dial the given node and perform a handshake,
  286. // returning the created Conn if successful.
  287. func (s *Suite) dial() (*Conn, error) {
  288. var conn Conn
  289. fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
  290. if err != nil {
  291. return nil, err
  292. }
  293. conn.Conn = rlpx.NewConn(fd, s.Dest.Pubkey())
  294. // do encHandshake
  295. conn.ourKey, _ = crypto.GenerateKey()
  296. _, err = conn.Handshake(conn.ourKey)
  297. if err != nil {
  298. return nil, err
  299. }
  300. return &conn, nil
  301. }