protocol_test.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. package eth
  2. import (
  3. "bytes"
  4. "io"
  5. "log"
  6. "math/big"
  7. "os"
  8. "testing"
  9. "time"
  10. "github.com/ethereum/go-ethereum/core/types"
  11. "github.com/ethereum/go-ethereum/crypto"
  12. "github.com/ethereum/go-ethereum/errs"
  13. "github.com/ethereum/go-ethereum/ethutil"
  14. ethlogger "github.com/ethereum/go-ethereum/logger"
  15. "github.com/ethereum/go-ethereum/p2p"
  16. "github.com/ethereum/go-ethereum/p2p/discover"
  17. )
  18. var logsys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
  19. var ini = false
  20. func logInit() {
  21. if !ini {
  22. ethlogger.AddLogSystem(logsys)
  23. ini = true
  24. }
  25. }
  26. type testMsgReadWriter struct {
  27. in chan p2p.Msg
  28. out []p2p.Msg
  29. }
  30. func (self *testMsgReadWriter) In(msg p2p.Msg) {
  31. self.in <- msg
  32. }
  33. func (self *testMsgReadWriter) Out() (msg p2p.Msg, ok bool) {
  34. if len(self.out) > 0 {
  35. msg = self.out[0]
  36. self.out = self.out[1:]
  37. ok = true
  38. }
  39. return
  40. }
  41. func (self *testMsgReadWriter) WriteMsg(msg p2p.Msg) error {
  42. self.out = append(self.out, msg)
  43. return nil
  44. }
  45. func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) {
  46. msg, ok := <-self.in
  47. if !ok {
  48. return msg, io.EOF
  49. }
  50. return msg, nil
  51. }
  52. type testTxPool struct {
  53. getTransactions func() []*types.Transaction
  54. addTransactions func(txs []*types.Transaction)
  55. }
  56. type testChainManager struct {
  57. getBlockHashes func(hash []byte, amount uint64) (hashes [][]byte)
  58. getBlock func(hash []byte) *types.Block
  59. status func() (td *big.Int, currentBlock []byte, genesisBlock []byte)
  60. }
  61. type testBlockPool struct {
  62. addBlockHashes func(next func() ([]byte, bool), peerId string)
  63. addBlock func(block *types.Block, peerId string) (err error)
  64. addPeer func(td *big.Int, currentBlock []byte, peerId string, requestHashes func([]byte) error, requestBlocks func([][]byte) error, peerError func(*errs.Error)) (best bool)
  65. removePeer func(peerId string)
  66. }
  67. // func (self *testTxPool) GetTransactions() (txs []*types.Transaction) {
  68. // if self.getTransactions != nil {
  69. // txs = self.getTransactions()
  70. // }
  71. // return
  72. // }
  73. func (self *testTxPool) AddTransactions(txs []*types.Transaction) {
  74. if self.addTransactions != nil {
  75. self.addTransactions(txs)
  76. }
  77. }
  78. func (self *testTxPool) GetTransactions() types.Transactions { return nil }
  79. func (self *testChainManager) GetBlockHashesFromHash(hash []byte, amount uint64) (hashes [][]byte) {
  80. if self.getBlockHashes != nil {
  81. hashes = self.getBlockHashes(hash, amount)
  82. }
  83. return
  84. }
  85. func (self *testChainManager) Status() (td *big.Int, currentBlock []byte, genesisBlock []byte) {
  86. if self.status != nil {
  87. td, currentBlock, genesisBlock = self.status()
  88. }
  89. return
  90. }
  91. func (self *testChainManager) GetBlock(hash []byte) (block *types.Block) {
  92. if self.getBlock != nil {
  93. block = self.getBlock(hash)
  94. }
  95. return
  96. }
  97. func (self *testBlockPool) AddBlockHashes(next func() ([]byte, bool), peerId string) {
  98. if self.addBlockHashes != nil {
  99. self.addBlockHashes(next, peerId)
  100. }
  101. }
  102. func (self *testBlockPool) AddBlock(block *types.Block, peerId string) {
  103. if self.addBlock != nil {
  104. self.addBlock(block, peerId)
  105. }
  106. }
  107. func (self *testBlockPool) AddPeer(td *big.Int, currentBlock []byte, peerId string, requestBlockHashes func([]byte) error, requestBlocks func([][]byte) error, peerError func(*errs.Error)) (best bool) {
  108. if self.addPeer != nil {
  109. best = self.addPeer(td, currentBlock, peerId, requestBlockHashes, requestBlocks, peerError)
  110. }
  111. return
  112. }
  113. func (self *testBlockPool) RemovePeer(peerId string) {
  114. if self.removePeer != nil {
  115. self.removePeer(peerId)
  116. }
  117. }
  118. func testPeer() *p2p.Peer {
  119. var id discover.NodeID
  120. pk := crypto.GenerateNewKeyPair().PublicKey
  121. copy(id[:], pk)
  122. return p2p.NewPeer(id, "test peer", []p2p.Cap{})
  123. }
  124. type ethProtocolTester struct {
  125. quit chan error
  126. rw *testMsgReadWriter // p2p.MsgReadWriter
  127. txPool *testTxPool // txPool
  128. chainManager *testChainManager // chainManager
  129. blockPool *testBlockPool // blockPool
  130. t *testing.T
  131. }
  132. func newEth(t *testing.T) *ethProtocolTester {
  133. return &ethProtocolTester{
  134. quit: make(chan error),
  135. rw: &testMsgReadWriter{in: make(chan p2p.Msg, 10)},
  136. txPool: &testTxPool{},
  137. chainManager: &testChainManager{},
  138. blockPool: &testBlockPool{},
  139. t: t,
  140. }
  141. }
  142. func (self *ethProtocolTester) reset() {
  143. self.rw = &testMsgReadWriter{in: make(chan p2p.Msg, 10)}
  144. self.quit = make(chan error)
  145. }
  146. func (self *ethProtocolTester) checkError(expCode int, delay time.Duration) (err error) {
  147. var timer = time.After(delay)
  148. select {
  149. case err = <-self.quit:
  150. case <-timer:
  151. self.t.Errorf("no error after %v, expected %v", delay, expCode)
  152. return
  153. }
  154. perr, ok := err.(*errs.Error)
  155. if ok && perr != nil {
  156. if code := perr.Code; code != expCode {
  157. self.t.Errorf("expected protocol error (code %v), got %v (%v)", expCode, code, err)
  158. }
  159. } else {
  160. self.t.Errorf("expected protocol error (code %v), got %v", expCode, err)
  161. }
  162. return
  163. }
  164. func (self *ethProtocolTester) In(msg p2p.Msg) {
  165. self.rw.In(msg)
  166. }
  167. func (self *ethProtocolTester) Out() (p2p.Msg, bool) {
  168. return self.rw.Out()
  169. }
  170. func (self *ethProtocolTester) checkMsg(i int, code uint64, val interface{}) (msg p2p.Msg) {
  171. if i >= len(self.rw.out) {
  172. self.t.Errorf("expected at least %v msgs, got %v", i, len(self.rw.out))
  173. return
  174. }
  175. msg = self.rw.out[i]
  176. if msg.Code != code {
  177. self.t.Errorf("expected msg code %v, got %v", code, msg.Code)
  178. }
  179. if val != nil {
  180. if err := msg.Decode(val); err != nil {
  181. self.t.Errorf("rlp encoding error: %v", err)
  182. }
  183. }
  184. return
  185. }
  186. func (self *ethProtocolTester) run() {
  187. err := runEthProtocol(self.txPool, self.chainManager, self.blockPool, testPeer(), self.rw)
  188. self.quit <- err
  189. }
  190. func TestStatusMsgErrors(t *testing.T) {
  191. logInit()
  192. eth := newEth(t)
  193. td := ethutil.Big1
  194. currentBlock := []byte{1}
  195. genesis := []byte{2}
  196. eth.chainManager.status = func() (*big.Int, []byte, []byte) { return td, currentBlock, genesis }
  197. go eth.run()
  198. statusMsg := p2p.NewMsg(4)
  199. eth.In(statusMsg)
  200. delay := 1 * time.Second
  201. eth.checkError(ErrNoStatusMsg, delay)
  202. var status statusMsgData
  203. eth.checkMsg(0, StatusMsg, &status) // first outgoing msg should be StatusMsg
  204. if status.TD.Cmp(td) != 0 ||
  205. status.ProtocolVersion != ProtocolVersion ||
  206. status.NetworkId != NetworkId ||
  207. status.TD.Cmp(td) != 0 ||
  208. bytes.Compare(status.CurrentBlock, currentBlock) != 0 ||
  209. bytes.Compare(status.GenesisBlock, genesis) != 0 {
  210. t.Errorf("incorrect outgoing status")
  211. }
  212. eth.reset()
  213. go eth.run()
  214. statusMsg = p2p.NewMsg(0, uint32(48), uint32(0), td, currentBlock, genesis)
  215. eth.In(statusMsg)
  216. eth.checkError(ErrProtocolVersionMismatch, delay)
  217. eth.reset()
  218. go eth.run()
  219. statusMsg = p2p.NewMsg(0, uint32(49), uint32(1), td, currentBlock, genesis)
  220. eth.In(statusMsg)
  221. eth.checkError(ErrNetworkIdMismatch, delay)
  222. eth.reset()
  223. go eth.run()
  224. statusMsg = p2p.NewMsg(0, uint32(49), uint32(0), td, currentBlock, []byte{3})
  225. eth.In(statusMsg)
  226. eth.checkError(ErrGenesisBlockMismatch, delay)
  227. }