iterator.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package trie
  2. import "bytes"
  3. type Iterator struct {
  4. trie *Trie
  5. Key []byte
  6. Value []byte
  7. }
  8. func NewIterator(trie *Trie) *Iterator {
  9. return &Iterator{trie: trie, Key: make([]byte, 32)}
  10. }
  11. func (self *Iterator) Next() bool {
  12. self.trie.mu.Lock()
  13. defer self.trie.mu.Unlock()
  14. key := RemTerm(CompactHexDecode(string(self.Key)))
  15. k := self.next(self.trie.root, key)
  16. self.Key = []byte(DecodeCompact(k))
  17. return len(k) > 0
  18. }
  19. func (self *Iterator) next(node Node, key []byte) []byte {
  20. if node == nil {
  21. return nil
  22. }
  23. switch node := node.(type) {
  24. case *FullNode:
  25. if len(key) > 0 {
  26. k := self.next(node.branch(key[0]), key[1:])
  27. if k != nil {
  28. return append([]byte{key[0]}, k...)
  29. }
  30. }
  31. var r byte
  32. if len(key) > 0 {
  33. r = key[0] + 1
  34. }
  35. for i := r; i < 16; i++ {
  36. k := self.key(node.branch(byte(i)))
  37. if k != nil {
  38. return append([]byte{i}, k...)
  39. }
  40. }
  41. case *ShortNode:
  42. k := RemTerm(node.Key())
  43. if vnode, ok := node.Value().(*ValueNode); ok {
  44. if bytes.Compare([]byte(k), key) > 0 {
  45. self.Value = vnode.Val()
  46. return k
  47. }
  48. } else {
  49. cnode := node.Value()
  50. var ret []byte
  51. skey := key[len(k):]
  52. if BeginsWith(key, k) {
  53. ret = self.next(cnode, skey)
  54. } else if bytes.Compare(k, key[:len(k)]) > 0 {
  55. ret = self.key(node)
  56. }
  57. if ret != nil {
  58. return append(k, ret...)
  59. }
  60. }
  61. }
  62. return nil
  63. }
  64. func (self *Iterator) key(node Node) []byte {
  65. switch node := node.(type) {
  66. case *ShortNode:
  67. // Leaf node
  68. if vnode, ok := node.Value().(*ValueNode); ok {
  69. k := RemTerm(node.Key())
  70. self.Value = vnode.Val()
  71. return k
  72. } else {
  73. k := RemTerm(node.Key())
  74. return append(k, self.key(node.Value())...)
  75. }
  76. case *FullNode:
  77. if node.Value() != nil {
  78. self.Value = node.Value().(*ValueNode).Val()
  79. return []byte{16}
  80. }
  81. for i := 0; i < 16; i++ {
  82. k := self.key(node.branch(byte(i)))
  83. if k != nil {
  84. return append([]byte{byte(i)}, k...)
  85. }
  86. }
  87. }
  88. return nil
  89. }