bmt_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. // Copyright 2017 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package bmt
  17. import (
  18. "bytes"
  19. "encoding/binary"
  20. "fmt"
  21. "math/rand"
  22. "sync"
  23. "sync/atomic"
  24. "testing"
  25. "time"
  26. "github.com/ethereum/go-ethereum/swarm/testutil"
  27. "golang.org/x/crypto/sha3"
  28. )
// BufferSize is the actual data length generated by the tests
// (could be longer than the max datalength of the BMT).
const BufferSize = 4128

const (
	// segmentCount is the maximum number of segments of the underlying chunk
	// Should be equal to max-chunk-data-size / hash-size
	// Currently set to 128 == 4096 (default chunk size) / 32 (sha3.keccak256 size)
	segmentCount = 128
)

// counts is the set of segment counts exercised by the correctness tests.
var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
  38. // calculates the Keccak256 SHA3 hash of the data
  39. func sha3hash(data ...[]byte) []byte {
  40. h := sha3.NewLegacyKeccak256()
  41. return doSum(h, nil, data...)
  42. }
  43. // TestRefHasher tests that the RefHasher computes the expected BMT hash for
  44. // some small data lengths
  45. func TestRefHasher(t *testing.T) {
  46. // the test struct is used to specify the expected BMT hash for
  47. // segment counts between from and to and lengths from 1 to datalength
  48. type test struct {
  49. from int
  50. to int
  51. expected func([]byte) []byte
  52. }
  53. var tests []*test
  54. // all lengths in [0,64] should be:
  55. //
  56. // sha3hash(data)
  57. //
  58. tests = append(tests, &test{
  59. from: 1,
  60. to: 2,
  61. expected: func(d []byte) []byte {
  62. data := make([]byte, 64)
  63. copy(data, d)
  64. return sha3hash(data)
  65. },
  66. })
  67. // all lengths in [3,4] should be:
  68. //
  69. // sha3hash(
  70. // sha3hash(data[:64])
  71. // sha3hash(data[64:])
  72. // )
  73. //
  74. tests = append(tests, &test{
  75. from: 3,
  76. to: 4,
  77. expected: func(d []byte) []byte {
  78. data := make([]byte, 128)
  79. copy(data, d)
  80. return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
  81. },
  82. })
  83. // all segmentCounts in [5,8] should be:
  84. //
  85. // sha3hash(
  86. // sha3hash(
  87. // sha3hash(data[:64])
  88. // sha3hash(data[64:128])
  89. // )
  90. // sha3hash(
  91. // sha3hash(data[128:192])
  92. // sha3hash(data[192:])
  93. // )
  94. // )
  95. //
  96. tests = append(tests, &test{
  97. from: 5,
  98. to: 8,
  99. expected: func(d []byte) []byte {
  100. data := make([]byte, 256)
  101. copy(data, d)
  102. return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
  103. },
  104. })
  105. // run the tests
  106. for i, x := range tests {
  107. for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
  108. for length := 1; length <= segmentCount*32; length++ {
  109. t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
  110. data := testutil.RandomBytes(i, length)
  111. expected := x.expected(data)
  112. actual := NewRefHasher(sha3.NewLegacyKeccak256, segmentCount).Hash(data)
  113. if !bytes.Equal(actual, expected) {
  114. t.Fatalf("expected %x, got %x", expected, actual)
  115. }
  116. })
  117. }
  118. }
  119. }
  120. }
  121. // tests if hasher responds with correct hash comparing the reference implementation return value
  122. func TestHasherEmptyData(t *testing.T) {
  123. hasher := sha3.NewLegacyKeccak256
  124. var data []byte
  125. for _, count := range counts {
  126. t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
  127. pool := NewTreePool(hasher, count, PoolSize)
  128. defer pool.Drain(0)
  129. bmt := New(pool)
  130. rbmt := NewRefHasher(hasher, count)
  131. refHash := rbmt.Hash(data)
  132. expHash := syncHash(bmt, nil, data)
  133. if !bytes.Equal(expHash, refHash) {
  134. t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
  135. }
  136. })
  137. }
  138. }
  139. // tests sequential write with entire max size written in one go
  140. func TestSyncHasherCorrectness(t *testing.T) {
  141. data := testutil.RandomBytes(1, BufferSize)
  142. hasher := sha3.NewLegacyKeccak256
  143. size := hasher().Size()
  144. var err error
  145. for _, count := range counts {
  146. t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
  147. max := count * size
  148. var incr int
  149. capacity := 1
  150. pool := NewTreePool(hasher, count, capacity)
  151. defer pool.Drain(0)
  152. for n := 0; n <= max; n += incr {
  153. incr = 1 + rand.Intn(5)
  154. bmt := New(pool)
  155. err = testHasherCorrectness(bmt, hasher, data, n, count)
  156. if err != nil {
  157. t.Fatal(err)
  158. }
  159. }
  160. })
  161. }
  162. }
  163. // tests order-neutral concurrent writes with entire max size written in one go
  164. func TestAsyncCorrectness(t *testing.T) {
  165. data := testutil.RandomBytes(1, BufferSize)
  166. hasher := sha3.NewLegacyKeccak256
  167. size := hasher().Size()
  168. whs := []whenHash{first, last, random}
  169. for _, double := range []bool{false, true} {
  170. for _, wh := range whs {
  171. for _, count := range counts {
  172. t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) {
  173. max := count * size
  174. var incr int
  175. capacity := 1
  176. pool := NewTreePool(hasher, count, capacity)
  177. defer pool.Drain(0)
  178. for n := 1; n <= max; n += incr {
  179. incr = 1 + rand.Intn(5)
  180. bmt := New(pool)
  181. d := data[:n]
  182. rbmt := NewRefHasher(hasher, count)
  183. exp := rbmt.Hash(d)
  184. got := syncHash(bmt, nil, d)
  185. if !bytes.Equal(got, exp) {
  186. t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got)
  187. }
  188. sw := bmt.NewAsyncWriter(double)
  189. got = asyncHashRandom(sw, nil, d, wh)
  190. if !bytes.Equal(got, exp) {
  191. t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got)
  192. }
  193. }
  194. })
  195. }
  196. }
  197. }
  198. }
  199. // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize
  200. func TestHasherReuse(t *testing.T) {
  201. t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
  202. testHasherReuse(1, t)
  203. })
  204. t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
  205. testHasherReuse(PoolSize, t)
  206. })
  207. }
  208. // tests if bmt reuse is not corrupting result
  209. func testHasherReuse(poolsize int, t *testing.T) {
  210. hasher := sha3.NewLegacyKeccak256
  211. pool := NewTreePool(hasher, segmentCount, poolsize)
  212. defer pool.Drain(0)
  213. bmt := New(pool)
  214. for i := 0; i < 100; i++ {
  215. data := testutil.RandomBytes(1, BufferSize)
  216. n := rand.Intn(bmt.Size())
  217. err := testHasherCorrectness(bmt, hasher, data, n, segmentCount)
  218. if err != nil {
  219. t.Fatal(err)
  220. }
  221. }
  222. }
  223. // Tests if pool can be cleanly reused even in concurrent use by several hasher
  224. func TestBMTConcurrentUse(t *testing.T) {
  225. hasher := sha3.NewLegacyKeccak256
  226. pool := NewTreePool(hasher, segmentCount, PoolSize)
  227. defer pool.Drain(0)
  228. cycles := 100
  229. errc := make(chan error)
  230. for i := 0; i < cycles; i++ {
  231. go func() {
  232. bmt := New(pool)
  233. data := testutil.RandomBytes(1, BufferSize)
  234. n := rand.Intn(bmt.Size())
  235. errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
  236. }()
  237. }
  238. LOOP:
  239. for {
  240. select {
  241. case <-time.NewTimer(5 * time.Second).C:
  242. t.Fatal("timed out")
  243. case err := <-errc:
  244. if err != nil {
  245. t.Fatal(err)
  246. }
  247. cycles--
  248. if cycles == 0 {
  249. break LOOP
  250. }
  251. }
  252. }
  253. }
  254. // Tests BMT Hasher io.Writer interface is working correctly
  255. // even multiple short random write buffers
  256. func TestBMTWriterBuffers(t *testing.T) {
  257. hasher := sha3.NewLegacyKeccak256
  258. for _, count := range counts {
  259. t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
  260. errc := make(chan error)
  261. pool := NewTreePool(hasher, count, PoolSize)
  262. defer pool.Drain(0)
  263. n := count * 32
  264. bmt := New(pool)
  265. data := testutil.RandomBytes(1, n)
  266. rbmt := NewRefHasher(hasher, count)
  267. refHash := rbmt.Hash(data)
  268. expHash := syncHash(bmt, nil, data)
  269. if !bytes.Equal(expHash, refHash) {
  270. t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
  271. }
  272. attempts := 10
  273. f := func() error {
  274. bmt := New(pool)
  275. bmt.Reset()
  276. var buflen int
  277. for offset := 0; offset < n; offset += buflen {
  278. buflen = rand.Intn(n-offset) + 1
  279. read, err := bmt.Write(data[offset : offset+buflen])
  280. if err != nil {
  281. return err
  282. }
  283. if read != buflen {
  284. return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read)
  285. }
  286. }
  287. hash := bmt.Sum(nil)
  288. if !bytes.Equal(hash, expHash) {
  289. return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash)
  290. }
  291. return nil
  292. }
  293. for j := 0; j < attempts; j++ {
  294. go func() {
  295. errc <- f()
  296. }()
  297. }
  298. timeout := time.NewTimer(2 * time.Second)
  299. for {
  300. select {
  301. case err := <-errc:
  302. if err != nil {
  303. t.Fatal(err)
  304. }
  305. attempts--
  306. if attempts == 0 {
  307. return
  308. }
  309. case <-timeout.C:
  310. t.Fatalf("timeout")
  311. }
  312. }
  313. })
  314. }
  315. }
  316. // helper function that compares reference and optimised implementations on
  317. // correctness
  318. func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
  319. span := make([]byte, 8)
  320. if len(d) < n {
  321. n = len(d)
  322. }
  323. binary.BigEndian.PutUint64(span, uint64(n))
  324. data := d[:n]
  325. rbmt := NewRefHasher(hasher, count)
  326. exp := sha3hash(span, rbmt.Hash(data))
  327. got := syncHash(bmt, span, data)
  328. if !bytes.Equal(got, exp) {
  329. return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
  330. }
  331. return err
  332. }
  333. //
  334. func BenchmarkBMT(t *testing.B) {
  335. for size := 4096; size >= 128; size /= 2 {
  336. t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) {
  337. benchmarkSHA3(t, size)
  338. })
  339. t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) {
  340. benchmarkBMTBaseline(t, size)
  341. })
  342. t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) {
  343. benchmarkRefHasher(t, size)
  344. })
  345. t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) {
  346. benchmarkBMT(t, size)
  347. })
  348. }
  349. }
// whenHash selects at which point of the segment writes the async tests and
// benchmarks launch the Sum call.
type whenHash = int

const (
	first whenHash = iota // launch Sum concurrently after the first segment write
	last                  // call Sum synchronously after the last segment write
	random                // launch Sum concurrently after a randomly chosen segment write
)
  356. func BenchmarkBMTAsync(t *testing.B) {
  357. whs := []whenHash{first, last, random}
  358. for size := 4096; size >= 128; size /= 2 {
  359. for _, wh := range whs {
  360. for _, double := range []bool{false, true} {
  361. t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) {
  362. benchmarkBMTAsync(t, size, wh, double)
  363. })
  364. }
  365. }
  366. }
  367. }
  368. func BenchmarkPool(t *testing.B) {
  369. caps := []int{1, PoolSize}
  370. for size := 4096; size >= 128; size /= 2 {
  371. for _, c := range caps {
  372. t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) {
  373. benchmarkPool(t, c, size)
  374. })
  375. }
  376. }
  377. }
  378. // benchmarks simple sha3 hash on chunks
  379. func benchmarkSHA3(t *testing.B, n int) {
  380. data := testutil.RandomBytes(1, n)
  381. hasher := sha3.NewLegacyKeccak256
  382. h := hasher()
  383. t.ReportAllocs()
  384. t.ResetTimer()
  385. for i := 0; i < t.N; i++ {
  386. doSum(h, nil, data)
  387. }
  388. }
  389. // benchmarks the minimum hashing time for a balanced (for simplicity) BMT
  390. // by doing count/segmentsize parallel hashings of 2*segmentsize bytes
  391. // doing it on n PoolSize each reusing the base hasher
  392. // the premise is that this is the minimum computation needed for a BMT
  393. // therefore this serves as a theoretical optimum for concurrent implementations
  394. func benchmarkBMTBaseline(t *testing.B, n int) {
  395. hasher := sha3.NewLegacyKeccak256
  396. hashSize := hasher().Size()
  397. data := testutil.RandomBytes(1, hashSize)
  398. t.ReportAllocs()
  399. t.ResetTimer()
  400. for i := 0; i < t.N; i++ {
  401. count := int32((n-1)/hashSize + 1)
  402. wg := sync.WaitGroup{}
  403. wg.Add(PoolSize)
  404. var i int32
  405. for j := 0; j < PoolSize; j++ {
  406. go func() {
  407. defer wg.Done()
  408. h := hasher()
  409. for atomic.AddInt32(&i, 1) < count {
  410. doSum(h, nil, data)
  411. }
  412. }()
  413. }
  414. wg.Wait()
  415. }
  416. }
  417. // benchmarks BMT Hasher
  418. func benchmarkBMT(t *testing.B, n int) {
  419. data := testutil.RandomBytes(1, n)
  420. hasher := sha3.NewLegacyKeccak256
  421. pool := NewTreePool(hasher, segmentCount, PoolSize)
  422. bmt := New(pool)
  423. t.ReportAllocs()
  424. t.ResetTimer()
  425. for i := 0; i < t.N; i++ {
  426. syncHash(bmt, nil, data)
  427. }
  428. }
  429. // benchmarks BMT hasher with asynchronous concurrent segment/section writes
  430. func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) {
  431. data := testutil.RandomBytes(1, n)
  432. hasher := sha3.NewLegacyKeccak256
  433. pool := NewTreePool(hasher, segmentCount, PoolSize)
  434. bmt := New(pool).NewAsyncWriter(double)
  435. idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
  436. rand.Shuffle(len(idxs), func(i int, j int) {
  437. idxs[i], idxs[j] = idxs[j], idxs[i]
  438. })
  439. t.ReportAllocs()
  440. t.ResetTimer()
  441. for i := 0; i < t.N; i++ {
  442. asyncHash(bmt, nil, n, wh, idxs, segments)
  443. }
  444. }
  445. // benchmarks 100 concurrent bmt hashes with pool capacity
  446. func benchmarkPool(t *testing.B, poolsize, n int) {
  447. data := testutil.RandomBytes(1, n)
  448. hasher := sha3.NewLegacyKeccak256
  449. pool := NewTreePool(hasher, segmentCount, poolsize)
  450. cycles := 100
  451. t.ReportAllocs()
  452. t.ResetTimer()
  453. wg := sync.WaitGroup{}
  454. for i := 0; i < t.N; i++ {
  455. wg.Add(cycles)
  456. for j := 0; j < cycles; j++ {
  457. go func() {
  458. defer wg.Done()
  459. bmt := New(pool)
  460. syncHash(bmt, nil, data)
  461. }()
  462. }
  463. wg.Wait()
  464. }
  465. }
  466. // benchmarks the reference hasher
  467. func benchmarkRefHasher(t *testing.B, n int) {
  468. data := testutil.RandomBytes(1, n)
  469. hasher := sha3.NewLegacyKeccak256
  470. rbmt := NewRefHasher(hasher, 128)
  471. t.ReportAllocs()
  472. t.ResetTimer()
  473. for i := 0; i < t.N; i++ {
  474. rbmt.Hash(data)
  475. }
  476. }
// syncHash hashes data together with the given span using the BMT hasher:
// the hasher is reset with the span, the data is written sequentially in one
// go, and the digest is finalized with Sum.
func syncHash(h *Hasher, span, data []byte) []byte {
	h.ResetWithLength(span)
	h.Write(data)
	return h.Sum(nil)
}
  483. func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) {
  484. l := len(data)
  485. n := l / secsize
  486. if l%secsize > 0 {
  487. n++
  488. }
  489. for i := 0; i < n; i++ {
  490. idxs = append(idxs, i)
  491. end := (i + 1) * secsize
  492. if end > l {
  493. end = l
  494. }
  495. section := data[i*secsize : end]
  496. segments = append(segments, section)
  497. }
  498. rand.Shuffle(n, func(i int, j int) {
  499. idxs[i], idxs[j] = idxs[j], idxs[i]
  500. })
  501. return idxs, segments
  502. }
  503. // splits the input data performs a random shuffle to mock async section writes
  504. func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) {
  505. idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
  506. return asyncHash(bmt, span, len(data), wh, idxs, segments)
  507. }
  508. // mock for async section writes for BMT SectionWriter
  509. // requires a permutation (a random shuffle) of list of all indexes of segments
  510. // and writes them in order to the appropriate section
  511. // the Sum function is called according to the wh parameter (first, last, random [relative to segment writes])
  512. func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) {
  513. bmt.Reset()
  514. if l == 0 {
  515. return bmt.Sum(nil, l, span)
  516. }
  517. c := make(chan []byte, 1)
  518. hashf := func() {
  519. c <- bmt.Sum(nil, l, span)
  520. }
  521. maxsize := len(idxs)
  522. var r int
  523. if wh == random {
  524. r = rand.Intn(maxsize)
  525. }
  526. for i, idx := range idxs {
  527. bmt.Write(idx, segments[idx])
  528. if (wh == first || wh == random) && i == r {
  529. go hashf()
  530. }
  531. }
  532. if wh == last {
  533. return bmt.Sum(nil, l, span)
  534. }
  535. return <-c
  536. }