consensus_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package beacon
  2. import (
  3. "fmt"
  4. "math/big"
  5. "testing"
  6. "github.com/ethereum/go-ethereum/common"
  7. "github.com/ethereum/go-ethereum/consensus"
  8. "github.com/ethereum/go-ethereum/core/types"
  9. "github.com/ethereum/go-ethereum/params"
  10. )
  11. type mockChain struct {
  12. config *params.ChainConfig
  13. tds map[uint64]*big.Int
  14. }
  15. func newMockChain() *mockChain {
  16. return &mockChain{
  17. config: new(params.ChainConfig),
  18. tds: make(map[uint64]*big.Int),
  19. }
  20. }
  21. func (m *mockChain) Config() *params.ChainConfig {
  22. return m.config
  23. }
  24. func (m *mockChain) CurrentHeader() *types.Header { panic("not implemented") }
  25. func (m *mockChain) GetHeader(hash common.Hash, number uint64) *types.Header {
  26. panic("not implemented")
  27. }
  28. func (m *mockChain) GetHeaderByNumber(number uint64) *types.Header { panic("not implemented") }
  29. func (m *mockChain) GetHeaderByHash(hash common.Hash) *types.Header { panic("not implemented") }
  30. func (m *mockChain) GetTd(hash common.Hash, number uint64) *big.Int {
  31. num, ok := m.tds[number]
  32. if ok {
  33. return new(big.Int).Set(num)
  34. }
  35. return nil
  36. }
  37. func TestVerifyTerminalBlock(t *testing.T) {
  38. chain := newMockChain()
  39. chain.tds[0] = big.NewInt(10)
  40. chain.config.TerminalTotalDifficulty = big.NewInt(50)
  41. tests := []struct {
  42. preHeaders []*types.Header
  43. ttd *big.Int
  44. err error
  45. index int
  46. }{
  47. // valid ttd
  48. {
  49. preHeaders: []*types.Header{
  50. {Number: big.NewInt(1), Difficulty: big.NewInt(10)},
  51. {Number: big.NewInt(2), Difficulty: big.NewInt(10)},
  52. {Number: big.NewInt(3), Difficulty: big.NewInt(10)},
  53. {Number: big.NewInt(4), Difficulty: big.NewInt(10)},
  54. },
  55. ttd: big.NewInt(50),
  56. },
  57. // last block doesn't reach ttd
  58. {
  59. preHeaders: []*types.Header{
  60. {Number: big.NewInt(1), Difficulty: big.NewInt(10)},
  61. {Number: big.NewInt(2), Difficulty: big.NewInt(10)},
  62. {Number: big.NewInt(3), Difficulty: big.NewInt(10)},
  63. {Number: big.NewInt(4), Difficulty: big.NewInt(9)},
  64. },
  65. ttd: big.NewInt(50),
  66. err: consensus.ErrInvalidTerminalBlock,
  67. index: 3,
  68. },
  69. // two blocks reach ttd
  70. {
  71. preHeaders: []*types.Header{
  72. {Number: big.NewInt(1), Difficulty: big.NewInt(10)},
  73. {Number: big.NewInt(2), Difficulty: big.NewInt(10)},
  74. {Number: big.NewInt(3), Difficulty: big.NewInt(20)},
  75. {Number: big.NewInt(4), Difficulty: big.NewInt(10)},
  76. },
  77. ttd: big.NewInt(50),
  78. err: consensus.ErrInvalidTerminalBlock,
  79. index: 3,
  80. },
  81. // three blocks reach ttd
  82. {
  83. preHeaders: []*types.Header{
  84. {Number: big.NewInt(1), Difficulty: big.NewInt(10)},
  85. {Number: big.NewInt(2), Difficulty: big.NewInt(10)},
  86. {Number: big.NewInt(3), Difficulty: big.NewInt(20)},
  87. {Number: big.NewInt(4), Difficulty: big.NewInt(10)},
  88. {Number: big.NewInt(4), Difficulty: big.NewInt(10)},
  89. },
  90. ttd: big.NewInt(50),
  91. err: consensus.ErrInvalidTerminalBlock,
  92. index: 3,
  93. },
  94. // parent reached ttd
  95. {
  96. preHeaders: []*types.Header{
  97. {Number: big.NewInt(1), Difficulty: big.NewInt(10)},
  98. },
  99. ttd: big.NewInt(9),
  100. err: consensus.ErrInvalidTerminalBlock,
  101. index: 0,
  102. },
  103. // unknown parent
  104. {
  105. preHeaders: []*types.Header{
  106. {Number: big.NewInt(4), Difficulty: big.NewInt(10)},
  107. },
  108. ttd: big.NewInt(9),
  109. err: consensus.ErrUnknownAncestor,
  110. index: 0,
  111. },
  112. }
  113. for i, test := range tests {
  114. fmt.Printf("Test: %v\n", i)
  115. chain.config.TerminalTotalDifficulty = test.ttd
  116. index, err := verifyTerminalPoWBlock(chain, test.preHeaders)
  117. if err != test.err {
  118. t.Fatalf("Invalid error encountered, expected %v got %v", test.err, err)
  119. }
  120. if index != test.index {
  121. t.Fatalf("Invalid index, expected %v got %v", test.index, index)
  122. }
  123. }
  124. }