trie.go 6.7 KB


  1. package trie
  2. import (
  3. "bytes"
  4. "container/list"
  5. "fmt"
  6. "sync"
  7. "github.com/ethereum/go-ethereum/crypto"
  8. "github.com/ethereum/go-ethereum/ethutil"
  9. )
  10. func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
  11. t2 := New(nil, backend)
  12. it := t1.Iterator()
  13. for it.Next() {
  14. t2.Update(it.Key, it.Value)
  15. }
  16. return bytes.Equal(t2.Hash(), t1.Hash()), t2
  17. }
// Trie bundles the current root node with its most recently computed hash,
// the node cache, and a history of earlier root hashes.
type Trie struct {
	// mu is taken by Commit, Reset, Update, Get and Delete.
	mu sync.Mutex
	// root is the current root node; New leaves it nil when no root hash is given.
	root Node
	// roothash is the hash passed to New or the one last recorded by Hash.
	roothash []byte
	// cache is created with NewCache(backend) and is used to load and flush nodes.
	cache *Cache
	// revisions stores previous roothash values ([]byte); Hash pushes, Reset pops.
	revisions *list.List
}
  25. func New(root []byte, backend Backend) *Trie {
  26. trie := &Trie{}
  27. trie.revisions = list.New()
  28. trie.roothash = root
  29. trie.cache = NewCache(backend)
  30. if root != nil {
  31. value := ethutil.NewValueFromBytes(trie.cache.Get(root))
  32. trie.root = trie.mknode(value)
  33. }
  34. return trie
  35. }
  36. func (self *Trie) Iterator() *Iterator {
  37. return NewIterator(self)
  38. }
  39. func (self *Trie) Copy() *Trie {
  40. return New(self.roothash, self.cache.backend)
  41. }
  42. // Legacy support
  43. func (self *Trie) Root() []byte { return self.Hash() }
  44. func (self *Trie) Hash() []byte {
  45. var hash []byte
  46. if self.root != nil {
  47. t := self.root.Hash()
  48. if byts, ok := t.([]byte); ok && len(byts) > 0 {
  49. hash = byts
  50. } else {
  51. hash = crypto.Sha3(ethutil.Encode(self.root.RlpData()))
  52. }
  53. } else {
  54. hash = crypto.Sha3(ethutil.Encode(""))
  55. }
  56. if !bytes.Equal(hash, self.roothash) {
  57. self.revisions.PushBack(self.roothash)
  58. self.roothash = hash
  59. }
  60. return hash
  61. }
  62. func (self *Trie) Commit() {
  63. self.mu.Lock()
  64. defer self.mu.Unlock()
  65. // Hash first
  66. self.Hash()
  67. self.cache.Flush()
  68. }
  69. // Reset should only be called if the trie has been hashed
  70. func (self *Trie) Reset() {
  71. self.mu.Lock()
  72. defer self.mu.Unlock()
  73. self.cache.Reset()
  74. if self.revisions.Len() > 0 {
  75. revision := self.revisions.Remove(self.revisions.Back()).([]byte)
  76. self.roothash = revision
  77. }
  78. value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash))
  79. self.root = self.mknode(value)
  80. }
  81. func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) }
  82. func (self *Trie) Update(key, value []byte) Node {
  83. self.mu.Lock()
  84. defer self.mu.Unlock()
  85. k := CompactHexDecode(string(key))
  86. if len(value) != 0 {
  87. self.root = self.insert(self.root, k, &ValueNode{self, value})
  88. } else {
  89. self.root = self.delete(self.root, k)
  90. }
  91. return self.root
  92. }
  93. func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) }
  94. func (self *Trie) Get(key []byte) []byte {
  95. self.mu.Lock()
  96. defer self.mu.Unlock()
  97. k := CompactHexDecode(string(key))
  98. n := self.get(self.root, k)
  99. if n != nil {
  100. return n.(*ValueNode).Val()
  101. }
  102. return nil
  103. }
  104. func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) }
  105. func (self *Trie) Delete(key []byte) Node {
  106. self.mu.Lock()
  107. defer self.mu.Unlock()
  108. k := CompactHexDecode(string(key))
  109. self.root = self.delete(self.root, k)
  110. return self.root
  111. }
  112. func (self *Trie) insert(node Node, key []byte, value Node) Node {
  113. if len(key) == 0 {
  114. return value
  115. }
  116. if node == nil {
  117. return NewShortNode(self, key, value)
  118. }
  119. switch node := node.(type) {
  120. case *ShortNode:
  121. k := node.Key()
  122. cnode := node.Value()
  123. if bytes.Equal(k, key) {
  124. return NewShortNode(self, key, value)
  125. }
  126. var n Node
  127. matchlength := MatchingNibbleLength(key, k)
  128. if matchlength == len(k) {
  129. n = self.insert(cnode, key[matchlength:], value)
  130. } else {
  131. pnode := self.insert(nil, k[matchlength+1:], cnode)
  132. nnode := self.insert(nil, key[matchlength+1:], value)
  133. fulln := NewFullNode(self)
  134. fulln.set(k[matchlength], pnode)
  135. fulln.set(key[matchlength], nnode)
  136. n = fulln
  137. }
  138. if matchlength == 0 {
  139. return n
  140. }
  141. return NewShortNode(self, key[:matchlength], n)
  142. case *FullNode:
  143. cpy := node.Copy().(*FullNode)
  144. cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
  145. return cpy
  146. default:
  147. panic(fmt.Sprintf("%T: invalid node: %v", node, node))
  148. }
  149. }
  150. func (self *Trie) get(node Node, key []byte) Node {
  151. if len(key) == 0 {
  152. return node
  153. }
  154. if node == nil {
  155. return nil
  156. }
  157. switch node := node.(type) {
  158. case *ShortNode:
  159. k := node.Key()
  160. cnode := node.Value()
  161. if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) {
  162. return self.get(cnode, key[len(k):])
  163. }
  164. return nil
  165. case *FullNode:
  166. return self.get(node.branch(key[0]), key[1:])
  167. default:
  168. panic(fmt.Sprintf("%T: invalid node: %v", node, node))
  169. }
  170. }
  171. func (self *Trie) delete(node Node, key []byte) Node {
  172. if len(key) == 0 && node == nil {
  173. return nil
  174. }
  175. switch node := node.(type) {
  176. case *ShortNode:
  177. k := node.Key()
  178. cnode := node.Value()
  179. if bytes.Equal(key, k) {
  180. return nil
  181. } else if bytes.Equal(key[:len(k)], k) {
  182. child := self.delete(cnode, key[len(k):])
  183. var n Node
  184. switch child := child.(type) {
  185. case *ShortNode:
  186. nkey := append(k, child.Key()...)
  187. n = NewShortNode(self, nkey, child.Value())
  188. case *FullNode:
  189. sn := NewShortNode(self, node.Key(), child)
  190. sn.key = node.key
  191. n = sn
  192. }
  193. return n
  194. } else {
  195. return node
  196. }
  197. case *FullNode:
  198. n := node.Copy().(*FullNode)
  199. n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
  200. pos := -1
  201. for i := 0; i < 17; i++ {
  202. if n.branch(byte(i)) != nil {
  203. if pos == -1 {
  204. pos = i
  205. } else {
  206. pos = -2
  207. }
  208. }
  209. }
  210. var nnode Node
  211. if pos == 16 {
  212. nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
  213. } else if pos >= 0 {
  214. cnode := n.branch(byte(pos))
  215. switch cnode := cnode.(type) {
  216. case *ShortNode:
  217. // Stitch keys
  218. k := append([]byte{byte(pos)}, cnode.Key()...)
  219. nnode = NewShortNode(self, k, cnode.Value())
  220. case *FullNode:
  221. nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
  222. }
  223. } else {
  224. nnode = n
  225. }
  226. return nnode
  227. case nil:
  228. return nil
  229. default:
  230. panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key))
  231. }
  232. }
  233. // casting functions and cache storing
  234. func (self *Trie) mknode(value *ethutil.Value) Node {
  235. l := value.Len()
  236. switch l {
  237. case 0:
  238. return nil
  239. case 2:
  240. // A value node may consists of 2 bytes.
  241. if value.Get(0).Len() != 0 {
  242. return NewShortNode(self, CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1)))
  243. }
  244. case 17:
  245. fnode := NewFullNode(self)
  246. for i := 0; i < l; i++ {
  247. fnode.set(byte(i), self.mknode(value.Get(i)))
  248. }
  249. return fnode
  250. case 32:
  251. return &HashNode{value.Bytes()}
  252. }
  253. return &ValueNode{self, value.Bytes()}
  254. }
  255. func (self *Trie) trans(node Node) Node {
  256. switch node := node.(type) {
  257. case *HashNode:
  258. value := ethutil.NewValueFromBytes(self.cache.Get(node.key))
  259. return self.mknode(value)
  260. default:
  261. return node
  262. }
  263. }
  264. func (self *Trie) store(node Node) interface{} {
  265. data := ethutil.Encode(node)
  266. if len(data) >= 32 {
  267. key := crypto.Sha3(data)
  268. self.cache.Put(key, data)
  269. return key
  270. }
  271. return node.RlpData()
  272. }
  273. func (self *Trie) PrintRoot() {
  274. fmt.Println(self.root)
  275. }