// bmt_test.go
  1. // Copyright 2017 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package bmt
  17. import (
  18. "bytes"
  19. crand "crypto/rand"
  20. "encoding/binary"
  21. "fmt"
  22. "io"
  23. "math/rand"
  24. "sync"
  25. "sync/atomic"
  26. "testing"
  27. "time"
  28. "github.com/ethereum/go-ethereum/crypto/sha3"
  29. )
// the actual data length generated (could be longer than max datalength of the BMT)
const BufferSize = 4128

// counts is the set of BMT segment counts exercised by the table-driven
// tests below (mix of powers of two and odd values around the boundaries)
var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
  33. // calculates the Keccak256 SHA3 hash of the data
  34. func sha3hash(data ...[]byte) []byte {
  35. h := sha3.NewKeccak256()
  36. return doSum(h, nil, data...)
  37. }
  38. // TestRefHasher tests that the RefHasher computes the expected BMT hash for
  39. // some small data lengths
  40. func TestRefHasher(t *testing.T) {
  41. // the test struct is used to specify the expected BMT hash for
  42. // segment counts between from and to and lengths from 1 to datalength
  43. type test struct {
  44. from int
  45. to int
  46. expected func([]byte) []byte
  47. }
  48. var tests []*test
  49. // all lengths in [0,64] should be:
  50. //
  51. // sha3hash(data)
  52. //
  53. tests = append(tests, &test{
  54. from: 1,
  55. to: 2,
  56. expected: func(d []byte) []byte {
  57. data := make([]byte, 64)
  58. copy(data, d)
  59. return sha3hash(data)
  60. },
  61. })
  62. // all lengths in [3,4] should be:
  63. //
  64. // sha3hash(
  65. // sha3hash(data[:64])
  66. // sha3hash(data[64:])
  67. // )
  68. //
  69. tests = append(tests, &test{
  70. from: 3,
  71. to: 4,
  72. expected: func(d []byte) []byte {
  73. data := make([]byte, 128)
  74. copy(data, d)
  75. return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
  76. },
  77. })
  78. // all segmentCounts in [5,8] should be:
  79. //
  80. // sha3hash(
  81. // sha3hash(
  82. // sha3hash(data[:64])
  83. // sha3hash(data[64:128])
  84. // )
  85. // sha3hash(
  86. // sha3hash(data[128:192])
  87. // sha3hash(data[192:])
  88. // )
  89. // )
  90. //
  91. tests = append(tests, &test{
  92. from: 5,
  93. to: 8,
  94. expected: func(d []byte) []byte {
  95. data := make([]byte, 256)
  96. copy(data, d)
  97. return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
  98. },
  99. })
  100. // run the tests
  101. for _, x := range tests {
  102. for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
  103. for length := 1; length <= segmentCount*32; length++ {
  104. t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
  105. data := make([]byte, length)
  106. if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
  107. t.Fatal(err)
  108. }
  109. expected := x.expected(data)
  110. actual := NewRefHasher(sha3.NewKeccak256, segmentCount).Hash(data)
  111. if !bytes.Equal(actual, expected) {
  112. t.Fatalf("expected %x, got %x", expected, actual)
  113. }
  114. })
  115. }
  116. }
  117. }
  118. }
  119. // tests if hasher responds with correct hash comparing the reference implementation return value
  120. func TestHasherEmptyData(t *testing.T) {
  121. hasher := sha3.NewKeccak256
  122. var data []byte
  123. for _, count := range counts {
  124. t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
  125. pool := NewTreePool(hasher, count, PoolSize)
  126. defer pool.Drain(0)
  127. bmt := New(pool)
  128. rbmt := NewRefHasher(hasher, count)
  129. refHash := rbmt.Hash(data)
  130. expHash := syncHash(bmt, nil, data)
  131. if !bytes.Equal(expHash, refHash) {
  132. t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
  133. }
  134. })
  135. }
  136. }
  137. // tests sequential write with entire max size written in one go
  138. func TestSyncHasherCorrectness(t *testing.T) {
  139. data := newData(BufferSize)
  140. hasher := sha3.NewKeccak256
  141. size := hasher().Size()
  142. var err error
  143. for _, count := range counts {
  144. t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
  145. max := count * size
  146. var incr int
  147. capacity := 1
  148. pool := NewTreePool(hasher, count, capacity)
  149. defer pool.Drain(0)
  150. for n := 0; n <= max; n += incr {
  151. incr = 1 + rand.Intn(5)
  152. bmt := New(pool)
  153. err = testHasherCorrectness(bmt, hasher, data, n, count)
  154. if err != nil {
  155. t.Fatal(err)
  156. }
  157. }
  158. })
  159. }
  160. }
  161. // tests order-neutral concurrent writes with entire max size written in one go
  162. func TestAsyncCorrectness(t *testing.T) {
  163. data := newData(BufferSize)
  164. hasher := sha3.NewKeccak256
  165. size := hasher().Size()
  166. whs := []whenHash{first, last, random}
  167. for _, double := range []bool{false, true} {
  168. for _, wh := range whs {
  169. for _, count := range counts {
  170. t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) {
  171. max := count * size
  172. var incr int
  173. capacity := 1
  174. pool := NewTreePool(hasher, count, capacity)
  175. defer pool.Drain(0)
  176. for n := 1; n <= max; n += incr {
  177. incr = 1 + rand.Intn(5)
  178. bmt := New(pool)
  179. d := data[:n]
  180. rbmt := NewRefHasher(hasher, count)
  181. exp := rbmt.Hash(d)
  182. got := syncHash(bmt, nil, d)
  183. if !bytes.Equal(got, exp) {
  184. t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got)
  185. }
  186. sw := bmt.NewAsyncWriter(double)
  187. got = asyncHashRandom(sw, nil, d, wh)
  188. if !bytes.Equal(got, exp) {
  189. t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got)
  190. }
  191. }
  192. })
  193. }
  194. }
  195. }
  196. }
  197. // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize
  198. func TestHasherReuse(t *testing.T) {
  199. t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
  200. testHasherReuse(1, t)
  201. })
  202. t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
  203. testHasherReuse(PoolSize, t)
  204. })
  205. }
  206. // tests if bmt reuse is not corrupting result
  207. func testHasherReuse(poolsize int, t *testing.T) {
  208. hasher := sha3.NewKeccak256
  209. pool := NewTreePool(hasher, SegmentCount, poolsize)
  210. defer pool.Drain(0)
  211. bmt := New(pool)
  212. for i := 0; i < 100; i++ {
  213. data := newData(BufferSize)
  214. n := rand.Intn(bmt.Size())
  215. err := testHasherCorrectness(bmt, hasher, data, n, SegmentCount)
  216. if err != nil {
  217. t.Fatal(err)
  218. }
  219. }
  220. }
  221. // Tests if pool can be cleanly reused even in concurrent use by several hasher
  222. func TestBMTConcurrentUse(t *testing.T) {
  223. hasher := sha3.NewKeccak256
  224. pool := NewTreePool(hasher, SegmentCount, PoolSize)
  225. defer pool.Drain(0)
  226. cycles := 100
  227. errc := make(chan error)
  228. for i := 0; i < cycles; i++ {
  229. go func() {
  230. bmt := New(pool)
  231. data := newData(BufferSize)
  232. n := rand.Intn(bmt.Size())
  233. errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
  234. }()
  235. }
  236. LOOP:
  237. for {
  238. select {
  239. case <-time.NewTimer(5 * time.Second).C:
  240. t.Fatal("timed out")
  241. case err := <-errc:
  242. if err != nil {
  243. t.Fatal(err)
  244. }
  245. cycles--
  246. if cycles == 0 {
  247. break LOOP
  248. }
  249. }
  250. }
  251. }
// Tests BMT Hasher io.Writer interface is working correctly
// even multiple short random write buffers
func TestBMTWriterBuffers(t *testing.T) {
	hasher := sha3.NewKeccak256
	for _, count := range counts {
		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
			errc := make(chan error)
			pool := NewTreePool(hasher, count, PoolSize)
			defer pool.Drain(0)
			// n is the full data length for this segment count (32-byte segments)
			n := count * 32
			bmt := New(pool)
			data := newData(n)
			rbmt := NewRefHasher(hasher, count)
			refHash := rbmt.Hash(data)
			// establish the expected hash with a single full-size write,
			// cross-checked against the reference implementation
			expHash := syncHash(bmt, nil, data)
			if !bytes.Equal(expHash, refHash) {
				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
			}
			attempts := 10
			// f writes the same data in randomly sized chunks through the
			// io.Writer interface and checks the digest matches expHash
			f := func() error {
				bmt := New(pool)
				bmt.Reset()
				var buflen int
				for offset := 0; offset < n; offset += buflen {
					// random chunk size in [1, n-offset]
					buflen = rand.Intn(n-offset) + 1
					read, err := bmt.Write(data[offset : offset+buflen])
					if err != nil {
						return err
					}
					if read != buflen {
						return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read)
					}
				}
				hash := bmt.Sum(nil)
				if !bytes.Equal(hash, expHash) {
					return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash)
				}
				return nil
			}
			// run the chunked writers concurrently to exercise pool reuse
			for j := 0; j < attempts; j++ {
				go func() {
					errc <- f()
				}()
			}
			timeout := time.NewTimer(2 * time.Second)
			// collect all results or fail on the shared timeout
			for {
				select {
				case err := <-errc:
					if err != nil {
						t.Fatal(err)
					}
					attempts--
					if attempts == 0 {
						return
					}
				case <-timeout.C:
					t.Fatalf("timeout")
				}
			}
		})
	}
}
  314. // helper function that compares reference and optimised implementations on
  315. // correctness
  316. func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
  317. span := make([]byte, 8)
  318. if len(d) < n {
  319. n = len(d)
  320. }
  321. binary.BigEndian.PutUint64(span, uint64(n))
  322. data := d[:n]
  323. rbmt := NewRefHasher(hasher, count)
  324. exp := sha3hash(span, rbmt.Hash(data))
  325. got := syncHash(bmt, span, data)
  326. if !bytes.Equal(got, exp) {
  327. return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
  328. }
  329. return err
  330. }
  331. //
  332. func BenchmarkBMT(t *testing.B) {
  333. for size := 4096; size >= 128; size /= 2 {
  334. t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) {
  335. benchmarkSHA3(t, size)
  336. })
  337. t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) {
  338. benchmarkBMTBaseline(t, size)
  339. })
  340. t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) {
  341. benchmarkRefHasher(t, size)
  342. })
  343. t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) {
  344. benchmarkBMT(t, size)
  345. })
  346. }
  347. }
// whenHash selects at which point of the segment-write sequence the
// asynchronous Sum call is issued in the async tests and benchmarks
type whenHash = int

const (
	first  whenHash = iota // launch Sum right after the first segment write
	last                   // call Sum only after all segments are written
	random                 // launch Sum after a randomly chosen segment write
)
  354. func BenchmarkBMTAsync(t *testing.B) {
  355. whs := []whenHash{first, last, random}
  356. for size := 4096; size >= 128; size /= 2 {
  357. for _, wh := range whs {
  358. for _, double := range []bool{false, true} {
  359. t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) {
  360. benchmarkBMTAsync(t, size, wh, double)
  361. })
  362. }
  363. }
  364. }
  365. }
  366. func BenchmarkPool(t *testing.B) {
  367. caps := []int{1, PoolSize}
  368. for size := 4096; size >= 128; size /= 2 {
  369. for _, c := range caps {
  370. t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) {
  371. benchmarkPool(t, c, size)
  372. })
  373. }
  374. }
  375. }
  376. // benchmarks simple sha3 hash on chunks
  377. func benchmarkSHA3(t *testing.B, n int) {
  378. data := newData(n)
  379. hasher := sha3.NewKeccak256
  380. h := hasher()
  381. t.ReportAllocs()
  382. t.ResetTimer()
  383. for i := 0; i < t.N; i++ {
  384. doSum(h, nil, data)
  385. }
  386. }
  387. // benchmarks the minimum hashing time for a balanced (for simplicity) BMT
  388. // by doing count/segmentsize parallel hashings of 2*segmentsize bytes
  389. // doing it on n PoolSize each reusing the base hasher
  390. // the premise is that this is the minimum computation needed for a BMT
  391. // therefore this serves as a theoretical optimum for concurrent implementations
  392. func benchmarkBMTBaseline(t *testing.B, n int) {
  393. hasher := sha3.NewKeccak256
  394. hashSize := hasher().Size()
  395. data := newData(hashSize)
  396. t.ReportAllocs()
  397. t.ResetTimer()
  398. for i := 0; i < t.N; i++ {
  399. count := int32((n-1)/hashSize + 1)
  400. wg := sync.WaitGroup{}
  401. wg.Add(PoolSize)
  402. var i int32
  403. for j := 0; j < PoolSize; j++ {
  404. go func() {
  405. defer wg.Done()
  406. h := hasher()
  407. for atomic.AddInt32(&i, 1) < count {
  408. doSum(h, nil, data)
  409. }
  410. }()
  411. }
  412. wg.Wait()
  413. }
  414. }
  415. // benchmarks BMT Hasher
  416. func benchmarkBMT(t *testing.B, n int) {
  417. data := newData(n)
  418. hasher := sha3.NewKeccak256
  419. pool := NewTreePool(hasher, SegmentCount, PoolSize)
  420. bmt := New(pool)
  421. t.ReportAllocs()
  422. t.ResetTimer()
  423. for i := 0; i < t.N; i++ {
  424. syncHash(bmt, nil, data)
  425. }
  426. }
  427. // benchmarks BMT hasher with asynchronous concurrent segment/section writes
  428. func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) {
  429. data := newData(n)
  430. hasher := sha3.NewKeccak256
  431. pool := NewTreePool(hasher, SegmentCount, PoolSize)
  432. bmt := New(pool).NewAsyncWriter(double)
  433. idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
  434. shuffle(len(idxs), func(i int, j int) {
  435. idxs[i], idxs[j] = idxs[j], idxs[i]
  436. })
  437. t.ReportAllocs()
  438. t.ResetTimer()
  439. for i := 0; i < t.N; i++ {
  440. asyncHash(bmt, nil, n, wh, idxs, segments)
  441. }
  442. }
  443. // benchmarks 100 concurrent bmt hashes with pool capacity
  444. func benchmarkPool(t *testing.B, poolsize, n int) {
  445. data := newData(n)
  446. hasher := sha3.NewKeccak256
  447. pool := NewTreePool(hasher, SegmentCount, poolsize)
  448. cycles := 100
  449. t.ReportAllocs()
  450. t.ResetTimer()
  451. wg := sync.WaitGroup{}
  452. for i := 0; i < t.N; i++ {
  453. wg.Add(cycles)
  454. for j := 0; j < cycles; j++ {
  455. go func() {
  456. defer wg.Done()
  457. bmt := New(pool)
  458. syncHash(bmt, nil, data)
  459. }()
  460. }
  461. wg.Wait()
  462. }
  463. }
  464. // benchmarks the reference hasher
  465. func benchmarkRefHasher(t *testing.B, n int) {
  466. data := newData(n)
  467. hasher := sha3.NewKeccak256
  468. rbmt := NewRefHasher(hasher, 128)
  469. t.ReportAllocs()
  470. t.ResetTimer()
  471. for i := 0; i < t.N; i++ {
  472. rbmt.Hash(data)
  473. }
  474. }
  475. func newData(bufferSize int) []byte {
  476. data := make([]byte, bufferSize)
  477. _, err := io.ReadFull(crand.Reader, data)
  478. if err != nil {
  479. panic(err.Error())
  480. }
  481. return data
  482. }
// Hash hashes the data and the span using the bmt hasher:
// the hasher is reset with the span header, data is written in one
// sequential call and the resulting digest is returned
func syncHash(h *Hasher, span, data []byte) []byte {
	h.ResetWithLength(span)
	h.Write(data)
	return h.Sum(nil)
}
  489. func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) {
  490. l := len(data)
  491. n := l / secsize
  492. if l%secsize > 0 {
  493. n++
  494. }
  495. for i := 0; i < n; i++ {
  496. idxs = append(idxs, i)
  497. end := (i + 1) * secsize
  498. if end > l {
  499. end = l
  500. }
  501. section := data[i*secsize : end]
  502. segments = append(segments, section)
  503. }
  504. shuffle(n, func(i int, j int) {
  505. idxs[i], idxs[j] = idxs[j], idxs[i]
  506. })
  507. return idxs, segments
  508. }
// splits the input data performs a random shuffle to mock async section writes
// (the shuffled sections are then written via asyncHash; wh controls when the
// Sum call is issued relative to the writes)
func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) {
	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
	return asyncHash(bmt, span, len(data), wh, idxs, segments)
}
// mock for async section writes for BMT SectionWriter
// requires a permutation (a random shuffle) of list of all indexes of segments
// and writes them in order to the appropriate section
// the Sum function is called according to the wh parameter (first, last, random [relative to segment writes])
func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) {
	bmt.Reset()
	// zero-length input: nothing to write, hash immediately
	if l == 0 {
		return bmt.Sum(nil, l, span)
	}
	// buffered so the hashing goroutine never blocks on send
	c := make(chan []byte, 1)
	hashf := func() {
		c <- bmt.Sum(nil, l, span)
	}
	maxsize := len(idxs)
	var r int
	if wh == random {
		// the write index after which Sum is launched
		r = rand.Intn(maxsize)
	}
	for i, idx := range idxs {
		bmt.Write(idx, segments[idx])
		// for first, r stays 0 so Sum launches after the first write;
		// for random it launches after the r-th write
		if (wh == first || wh == random) && i == r {
			go hashf()
		}
	}
	if wh == last {
		// all sections written: Sum synchronously
		return bmt.Sum(nil, l, span)
	}
	// wait for the concurrently launched Sum
	return <-c
}
  543. // this is also in swarm/network_test.go
  544. // shuffle pseudo-randomizes the order of elements.
  545. // n is the number of elements. Shuffle panics if n < 0.
  546. // swap swaps the elements with indexes i and j.
  547. func shuffle(n int, swap func(i, j int)) {
  548. if n < 0 {
  549. panic("invalid argument to Shuffle")
  550. }
  551. // Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
  552. // Shuffle really ought not be called with n that doesn't fit in 32 bits.
  553. // Not only will it take a very long time, but with 2³¹! possible permutations,
  554. // there's no way that any PRNG can have a big enough internal state to
  555. // generate even a minuscule percentage of the possible permutations.
  556. // Nevertheless, the right API signature accepts an int n, so handle it as best we can.
  557. i := n - 1
  558. for ; i > 1<<31-1-1; i-- {
  559. j := int(rand.Int63n(int64(i + 1)))
  560. swap(i, j)
  561. }
  562. for ; i > 0; i-- {
  563. j := int(rand.Int31n(int32(i + 1)))
  564. swap(i, j)
  565. }
  566. }