فهرست منبع

trie: rework and document key encoding

'encode' and 'decode' are meaningless because the code deals with three
encodings. Document the encodings and give a name to each one.
Felix Lange 8 سال پیش
والد
کامیت
f958d7d482
7 فایل تغییر یافته به همراه 121 افزوده شده و 159 حذف شده
  1. 50 64
      trie/encoding.go
  2. 61 86
      trie/encoding_test.go
  3. 1 1
      trie/hasher.go
  4. 2 1
      trie/iterator.go
  5. 2 2
      trie/node.go
  6. 2 2
      trie/proof.go
  7. 3 3
      trie/trie.go

+ 50 - 64
trie/encoding.go

@@ -16,49 +16,54 @@
 
 package trie
 
-func compactEncode(hexSlice []byte) []byte {
+// Trie keys are dealt with in three distinct encodings:
+//
+// KEYBYTES encoding contains the actual key and nothing else. This encoding is the
+// input to most API functions.
+//
+// HEX encoding contains one byte for each nibble of the key and an optional trailing
+// 'terminator' byte of value 0x10 which indicates whether or not the node at the key
+// contains a value. Hex key encoding is used for nodes loaded in memory because it's
+// convenient to access.
+//
+// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix
+// encoding" there) and contains the bytes of the key and a flag. The high nibble of the
+// first byte contains the flag; the lowest bit encoding the oddness of the length and
+// the second-lowest encoding whether the node at the key is a value node. The low nibble
+// of the first byte is zero in the case of an even number of nibbles and the first nibble
+// in the case of an odd number. All remaining nibbles (now an even number) fit properly
+// into the remaining bytes. Compact encoding is used for nodes stored on disk.
+
+func hexToCompact(hex []byte) []byte {
 	terminator := byte(0)
-	if hexSlice[len(hexSlice)-1] == 16 {
+	if hasTerm(hex) {
 		terminator = 1
-		hexSlice = hexSlice[:len(hexSlice)-1]
-	}
-	var (
-		odd    = byte(len(hexSlice) % 2)
-		buflen = len(hexSlice)/2 + 1
-		bi, hi = 0, 0    // indices
-		hs     = byte(0) // shift: flips between 0 and 4
-	)
-	if odd == 0 {
-		bi = 1
-		hs = 4
+		hex = hex[:len(hex)-1]
 	}
-	buf := make([]byte, buflen)
-	buf[0] = terminator<<5 | byte(odd)<<4
-	for bi < len(buf) && hi < len(hexSlice) {
-		buf[bi] |= hexSlice[hi] << hs
-		if hs == 0 {
-			bi++
-		}
-		hi, hs = hi+1, hs^(1<<2)
+	buf := make([]byte, len(hex)/2+1)
+	buf[0] = terminator << 5 // the flag byte
+	if len(hex)&1 == 1 {
+		buf[0] |= 1 << 4 // odd flag
+		buf[0] |= hex[0] // first nibble is contained in the first byte
+		hex = hex[1:]
 	}
+	decodeNibbles(hex, buf[1:])
 	return buf
 }
 
-func compactDecode(str []byte) []byte {
-	base := compactHexDecode(str)
+func compactToHex(compact []byte) []byte {
+	base := keybytesToHex(compact)
 	base = base[:len(base)-1]
+	// apply terminator flag
 	if base[0] >= 2 {
 		base = append(base, 16)
 	}
-	if base[0]%2 == 1 {
-		base = base[1:]
-	} else {
-		base = base[2:]
-	}
-	return base
+	// apply odd flag
+	chop := 2 - base[0]&1
+	return base[chop:]
 }
 
-func compactHexDecode(str []byte) []byte {
+func keybytesToHex(str []byte) []byte {
 	l := len(str)*2 + 1
 	var nibbles = make([]byte, l)
 	for i, b := range str {
@@ -69,35 +74,24 @@ func compactHexDecode(str []byte) []byte {
 	return nibbles
 }
 
-// compactHexEncode encodes a series of nibbles into a byte array
-func compactHexEncode(nibbles []byte) []byte {
-	nl := len(nibbles)
-	if nl == 0 {
-		return nil
-	}
-	if nibbles[nl-1] == 16 {
-		nl--
+// hexToKeybytes turns hex nibbles into key bytes.
+// This can only be used for keys of even length.
+func hexToKeybytes(hex []byte) []byte {
+	if hasTerm(hex) {
+		hex = hex[:len(hex)-1]
 	}
-	l := (nl + 1) / 2
-	var str = make([]byte, l)
-	for i := range str {
-		b := nibbles[i*2] * 16
-		if nl > i*2 {
-			b += nibbles[i*2+1]
-		}
-		str[i] = b
+	if len(hex)&1 != 0 {
+		panic("can't convert hex key of odd length")
 	}
-	return str
+	key := make([]byte, (len(hex)+1)/2)
+	decodeNibbles(hex, key)
+	return key
 }
 
-func decodeCompact(key []byte) []byte {
-	l := len(key) / 2
-	var res = make([]byte, l)
-	for i := 0; i < l; i++ {
-		v1, v0 := key[2*i], key[2*i+1]
-		res[i] = v1*16 + v0
+func decodeNibbles(nibbles []byte, bytes []byte) {
+	for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 {
+		bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1]
 	}
-	return res
 }
 
 // prefixLen returns the length of the common prefix of a and b.
@@ -114,15 +108,7 @@ func prefixLen(a, b []byte) int {
 	return i
 }
 
+// hasTerm returns whether a hex key has the terminator flag.
 func hasTerm(s []byte) bool {
-	return s[len(s)-1] == 16
-}
-
-func remTerm(s []byte) []byte {
-	if hasTerm(s) {
-		b := make([]byte, len(s)-1)
-		copy(b, s)
-		return b
-	}
-	return s
+	return len(s) > 0 && s[len(s)-1] == 16
 }

+ 61 - 86
trie/encoding_test.go

@@ -17,113 +17,88 @@
 package trie
 
 import (
-	"encoding/hex"
+	"bytes"
 	"testing"
-
-	checker "gopkg.in/check.v1"
 )
 
-func TestEncoding(t *testing.T) { checker.TestingT(t) }
-
-type TrieEncodingSuite struct{}
-
-var _ = checker.Suite(&TrieEncodingSuite{})
-
-func (s *TrieEncodingSuite) TestCompactEncode(c *checker.C) {
-	// even compact encode
-	test1 := []byte{1, 2, 3, 4, 5}
-	res1 := compactEncode(test1)
-	c.Assert(res1, checker.DeepEquals, []byte("\x11\x23\x45"))
-
-	// odd compact encode
-	test2 := []byte{0, 1, 2, 3, 4, 5}
-	res2 := compactEncode(test2)
-	c.Assert(res2, checker.DeepEquals, []byte("\x00\x01\x23\x45"))
-
-	//odd terminated compact encode
-	test3 := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
-	res3 := compactEncode(test3)
-	c.Assert(res3, checker.DeepEquals, []byte("\x20\x0f\x1c\xb8"))
-
-	// even terminated compact encode
-	test4 := []byte{15, 1, 12, 11, 8 /*term*/, 16}
-	res4 := compactEncode(test4)
-	c.Assert(res4, checker.DeepEquals, []byte("\x3f\x1c\xb8"))
-}
-
-func (s *TrieEncodingSuite) TestCompactHexDecode(c *checker.C) {
-	exp := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
-	res := compactHexDecode([]byte("verb"))
-	c.Assert(res, checker.DeepEquals, exp)
-}
-
-func (s *TrieEncodingSuite) TestCompactHexEncode(c *checker.C) {
-	exp := []byte("verb")
-	res := compactHexEncode([]byte{7, 6, 6, 5, 7, 2, 6, 2, 16})
-	c.Assert(res, checker.DeepEquals, exp)
-}
-
-func (s *TrieEncodingSuite) TestCompactDecode(c *checker.C) {
-	// odd compact decode
-	exp := []byte{1, 2, 3, 4, 5}
-	res := compactDecode([]byte("\x11\x23\x45"))
-	c.Assert(res, checker.DeepEquals, exp)
-
-	// even compact decode
-	exp = []byte{0, 1, 2, 3, 4, 5}
-	res = compactDecode([]byte("\x00\x01\x23\x45"))
-	c.Assert(res, checker.DeepEquals, exp)
-
-	// even terminated compact decode
-	exp = []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
-	res = compactDecode([]byte("\x20\x0f\x1c\xb8"))
-	c.Assert(res, checker.DeepEquals, exp)
-
-	// even terminated compact decode
-	exp = []byte{15, 1, 12, 11, 8 /*term*/, 16}
-	res = compactDecode([]byte("\x3f\x1c\xb8"))
-	c.Assert(res, checker.DeepEquals, exp)
+func TestHexCompact(t *testing.T) {
+	tests := []struct{ hex, compact []byte }{
+		// empty keys, with and without terminator.
+		{hex: []byte{}, compact: []byte{0x00}},
+		{hex: []byte{16}, compact: []byte{0x20}},
+		// odd length, no terminator
+		{hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}},
+		// even length, no terminator
+		{hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}},
+		// odd length, terminator
+		{hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}},
+		// even length, terminator
+		{hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}},
+	}
+	for _, test := range tests {
+		if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) {
+			t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact)
+		}
+		if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) {
+			t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex)
+		}
+	}
 }
 
-func (s *TrieEncodingSuite) TestDecodeCompact(c *checker.C) {
-	exp, _ := hex.DecodeString("012345")
-	res := decodeCompact([]byte{0, 1, 2, 3, 4, 5})
-	c.Assert(res, checker.DeepEquals, exp)
-
-	exp, _ = hex.DecodeString("012345")
-	res = decodeCompact([]byte{0, 1, 2, 3, 4, 5, 16})
-	c.Assert(res, checker.DeepEquals, exp)
-
-	exp, _ = hex.DecodeString("abcdef")
-	res = decodeCompact([]byte{10, 11, 12, 13, 14, 15})
-	c.Assert(res, checker.DeepEquals, exp)
+func TestHexKeybytes(t *testing.T) {
+	tests := []struct{ key, hexIn, hexOut []byte }{
+		{key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}},
+		{key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}},
+		{
+			key:    []byte{0x12, 0x34, 0x56},
+			hexIn:  []byte{1, 2, 3, 4, 5, 6, 16},
+			hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
+		},
+		{
+			key:    []byte{0x12, 0x34, 0x5},
+			hexIn:  []byte{1, 2, 3, 4, 0, 5, 16},
+			hexOut: []byte{1, 2, 3, 4, 0, 5, 16},
+		},
+		{
+			key:    []byte{0x12, 0x34, 0x56},
+			hexIn:  []byte{1, 2, 3, 4, 5, 6},
+			hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
+		},
+	}
+	for _, test := range tests {
+		if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
+			t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
+		}
+		if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) {
+			t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key)
+		}
+	}
 }
 
-func BenchmarkCompactEncode(b *testing.B) {
-
-	testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
+func BenchmarkHexToCompact(b *testing.B) {
+	testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
 	for i := 0; i < b.N; i++ {
-		compactEncode(testBytes)
+		hexToCompact(testBytes)
 	}
 }
 
-func BenchmarkCompactDecode(b *testing.B) {
-	testBytes := []byte{0, 15, 1, 12, 11, 8 /*term*/, 16}
+func BenchmarkCompactToHex(b *testing.B) {
+	testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
 	for i := 0; i < b.N; i++ {
-		compactDecode(testBytes)
+		compactToHex(testBytes)
 	}
 }
 
-func BenchmarkCompactHexDecode(b *testing.B) {
+func BenchmarkKeybytesToHex(b *testing.B) {
 	testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
 	for i := 0; i < b.N; i++ {
-		compactHexDecode(testBytes)
+		keybytesToHex(testBytes)
 	}
 }
 
-func BenchmarkDecodeCompact(b *testing.B) {
+func BenchmarkHexToKeybytes(b *testing.B) {
 	testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
 	for i := 0; i < b.N; i++ {
-		decodeCompact(testBytes)
+		hexToKeybytes(testBytes)
 	}
 }

+ 1 - 1
trie/hasher.go

@@ -105,7 +105,7 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
 	case *shortNode:
 		// Hash the short node's child, caching the newly hashed subtree
 		collapsed, cached := n.copy(), n.copy()
-		collapsed.Key = compactEncode(n.Key)
+		collapsed.Key = hexToCompact(n.Key)
 		cached.Key = common.CopyBytes(n.Key)
 
 		if _, ok := n.Val.(valueNode); !ok {

+ 2 - 1
trie/iterator.go

@@ -19,6 +19,7 @@ package trie
 import (
 	"bytes"
 	"container/heap"
+
 	"github.com/ethereum/go-ethereum/common"
 )
 
@@ -48,7 +49,7 @@ func NewIteratorFromNodeIterator(it NodeIterator) *Iterator {
 func (it *Iterator) Next() bool {
 	for it.nodeIt.Next(true) {
 		if it.nodeIt.Leaf() {
-			it.Key = decodeCompact(it.nodeIt.Path())
+			it.Key = hexToKeybytes(it.nodeIt.Path())
 			it.Value = it.nodeIt.LeafBlob()
 			return true
 		}

+ 2 - 2
trie/node.go

@@ -139,8 +139,8 @@ func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
 		return nil, err
 	}
 	flag := nodeFlag{hash: hash, gen: cachegen}
-	key := compactDecode(kbuf)
-	if key[len(key)-1] == 16 {
+	key := compactToHex(kbuf)
+	if hasTerm(key) {
 		// value node
 		val, _, err := rlp.SplitString(rest)
 		if err != nil {

+ 2 - 2
trie/proof.go

@@ -38,7 +38,7 @@ import (
 // absence of the key.
 func (t *Trie) Prove(key []byte) []rlp.RawValue {
 	// Collect all nodes on the path to key.
-	key = compactHexDecode(key)
+	key = keybytesToHex(key)
 	nodes := []node{}
 	tn := t.root
 	for len(key) > 0 && tn != nil {
@@ -89,7 +89,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
 // returns an error if the proof contains invalid trie nodes or the
 // wrong value.
 func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) {
-	key = compactHexDecode(key)
+	key = keybytesToHex(key)
 	sha := sha3.NewKeccak256()
 	wantHash := rootHash.Bytes()
 	for i, buf := range proof {

+ 3 - 3
trie/trie.go

@@ -144,7 +144,7 @@ func (t *Trie) Get(key []byte) []byte {
 // The value bytes must not be modified by the caller.
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *Trie) TryGet(key []byte) ([]byte, error) {
-	key = compactHexDecode(key)
+	key = keybytesToHex(key)
 	value, newroot, didResolve, err := t.tryGet(t.root, key, 0)
 	if err == nil && didResolve {
 		t.root = newroot
@@ -211,7 +211,7 @@ func (t *Trie) Update(key, value []byte) {
 //
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *Trie) TryUpdate(key, value []byte) error {
-	k := compactHexDecode(key)
+	k := keybytesToHex(key)
 	if len(value) != 0 {
 		_, n, err := t.insert(t.root, nil, k, valueNode(value))
 		if err != nil {
@@ -307,7 +307,7 @@ func (t *Trie) Delete(key []byte) {
 // TryDelete removes any existing value for key from the trie.
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *Trie) TryDelete(key []byte) error {
-	k := compactHexDecode(key)
+	k := keybytesToHex(key)
 	_, n, err := t.delete(t.root, nil, k)
 	if err != nil {
 		return err