iterator.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package trie
  2. import (
  3. "bytes"
  4. )
  5. type Iterator struct {
  6. trie *Trie
  7. Key []byte
  8. Value []byte
  9. }
  10. func NewIterator(trie *Trie) *Iterator {
  11. return &Iterator{trie: trie, Key: nil}
  12. }
  13. func (self *Iterator) Next() bool {
  14. self.trie.mu.Lock()
  15. defer self.trie.mu.Unlock()
  16. isIterStart := false
  17. if self.Key == nil {
  18. isIterStart = true
  19. self.Key = make([]byte, 32)
  20. }
  21. key := RemTerm(CompactHexDecode(string(self.Key)))
  22. k := self.next(self.trie.root, key, isIterStart)
  23. self.Key = []byte(DecodeCompact(k))
  24. return len(k) > 0
  25. }
  26. func (self *Iterator) next(node Node, key []byte, isIterStart bool) []byte {
  27. if node == nil {
  28. return nil
  29. }
  30. switch node := node.(type) {
  31. case *FullNode:
  32. if len(key) > 0 {
  33. k := self.next(node.branch(key[0]), key[1:], isIterStart)
  34. if k != nil {
  35. return append([]byte{key[0]}, k...)
  36. }
  37. }
  38. var r byte
  39. if len(key) > 0 {
  40. r = key[0] + 1
  41. }
  42. for i := r; i < 16; i++ {
  43. k := self.key(node.branch(byte(i)))
  44. if k != nil {
  45. return append([]byte{i}, k...)
  46. }
  47. }
  48. case *ShortNode:
  49. k := RemTerm(node.Key())
  50. if vnode, ok := node.Value().(*ValueNode); ok {
  51. switch bytes.Compare([]byte(k), key) {
  52. case 0:
  53. if isIterStart {
  54. self.Value = vnode.Val()
  55. return k
  56. }
  57. case 1:
  58. self.Value = vnode.Val()
  59. return k
  60. }
  61. } else {
  62. cnode := node.Value()
  63. var ret []byte
  64. skey := key[len(k):]
  65. if BeginsWith(key, k) {
  66. ret = self.next(cnode, skey, isIterStart)
  67. } else if bytes.Compare(k, key[:len(k)]) > 0 {
  68. return self.key(node)
  69. }
  70. if ret != nil {
  71. return append(k, ret...)
  72. }
  73. }
  74. }
  75. return nil
  76. }
  77. func (self *Iterator) key(node Node) []byte {
  78. switch node := node.(type) {
  79. case *ShortNode:
  80. // Leaf node
  81. if vnode, ok := node.Value().(*ValueNode); ok {
  82. k := RemTerm(node.Key())
  83. self.Value = vnode.Val()
  84. return k
  85. } else {
  86. k := RemTerm(node.Key())
  87. return append(k, self.key(node.Value())...)
  88. }
  89. case *FullNode:
  90. if node.Value() != nil {
  91. self.Value = node.Value().(*ValueNode).Val()
  92. return []byte{16}
  93. }
  94. for i := 0; i < 16; i++ {
  95. k := self.key(node.branch(byte(i)))
  96. if k != nil {
  97. return append([]byte{byte(i)}, k...)
  98. }
  99. }
  100. }
  101. return nil
  102. }