Prechádzať zdrojové kódy

Merge pull request #1369 from obscuren/statedb-update-cleanup

core, core/state: throw out intermediate state
Jeffrey Wilcke 10 rokov pred
rodič
commit
9c3db1be1d

+ 2 - 2
core/block_processor.go

@@ -77,7 +77,7 @@ func (self *BlockProcessor) ApplyTransaction(coinbase *state.StateObject, stated
 	}
 
 	// Update the state with pending changes
-	statedb.Update()
+	statedb.SyncIntermediate()
 
 	usedGas.Add(usedGas, gas)
 	receipt := types.NewReceipt(statedb.Root().Bytes(), usedGas)
@@ -243,7 +243,7 @@ func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs st
 
 	// Commit state objects/accounts to a temporary trie (does not save)
 	// used to calculate the state root.
-	state.Update()
+	state.SyncObjects()
 	if header.Root != state.Root() {
 		err = fmt.Errorf("invalid merkle root. received=%x got=%x", header.Root, state.Root())
 		return

+ 2 - 2
core/chain_makers.go

@@ -77,7 +77,7 @@ func (b *BlockGen) AddTx(tx *types.Transaction) {
 	if err != nil {
 		panic(err)
 	}
-	b.statedb.Update()
+	b.statedb.SyncIntermediate()
 	b.header.GasUsed.Add(b.header.GasUsed, gas)
 	receipt := types.NewReceipt(b.statedb.Root().Bytes(), b.header.GasUsed)
 	logs := b.statedb.GetLogs(tx.Hash())
@@ -135,7 +135,7 @@ func GenerateChain(parent *types.Block, db common.Database, n int, gen func(int,
 			gen(i, b)
 		}
 		AccumulateRewards(statedb, h, b.uncles)
-		statedb.Update()
+		statedb.SyncIntermediate()
 		h.Root = statedb.Root()
 		return types.NewBlock(h, b.txs, b.uncles, b.receipts)
 	}

+ 1 - 1
core/genesis.go

@@ -64,7 +64,7 @@ func GenesisBlockForTesting(db common.Database, addr common.Address, balance *bi
 	statedb := state.New(common.Hash{}, db)
 	obj := statedb.GetOrNewStateObject(addr)
 	obj.SetBalance(balance)
-	statedb.Update()
+	statedb.SyncObjects()
 	statedb.Sync()
 	block := types.NewBlock(&types.Header{
 		Difficulty: params.GenesisDifficulty,

+ 0 - 8
core/state/state_object.go

@@ -57,8 +57,6 @@ type StateObject struct {
 	initCode Code
 	// Cached storage (flushed when updated)
 	storage Storage
-	// Temporary prepaid gas, reward after transition
-	prepaid *big.Int
 
 	// Total gas pool is the total amount of gas currently
 	// left if this object is the coinbase. Gas is directly
@@ -77,14 +75,10 @@ func (self *StateObject) Reset() {
 }
 
 func NewStateObject(address common.Address, db common.Database) *StateObject {
-	// This to ensure that it has 20 bytes (and not 0 bytes), thus left or right pad doesn't matter.
-	//address := common.ToAddress(addr)
-
 	object := &StateObject{db: db, address: address, balance: new(big.Int), gasPool: new(big.Int), dirty: true}
 	object.trie = trie.NewSecure((common.Hash{}).Bytes(), db)
 	object.storage = make(Storage)
 	object.gasPool = new(big.Int)
-	object.prepaid = new(big.Int)
 
 	return object
 }
@@ -110,7 +104,6 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db common.Data
 	object.trie = trie.NewSecure(extobject.Root[:], db)
 	object.storage = make(map[string]common.Hash)
 	object.gasPool = new(big.Int)
-	object.prepaid = new(big.Int)
 	object.code, _ = db.Get(extobject.CodeHash)
 
 	return object
@@ -172,7 +165,6 @@ func (self *StateObject) Update() {
 
 		self.setAddr([]byte(key), value)
 	}
-	self.storage = make(Storage)
 }
 
 func (c *StateObject) GetInstr(pc *big.Int) *common.Value {

+ 1 - 1
core/state/state_test.go

@@ -72,7 +72,7 @@ func TestNull(t *testing.T) {
 	//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
 	var value common.Hash
 	state.SetState(address, common.Hash{}, value)
-	state.Update()
+	state.SyncIntermediate()
 	state.Sync()
 	value = state.GetState(address, common.Hash{})
 	if !common.EmptyHash(value) {

+ 23 - 3
core/state/statedb.go

@@ -18,6 +18,7 @@ import (
 type StateDB struct {
 	db   common.Database
 	trie *trie.SecureTrie
+	root common.Hash
 
 	stateObjects map[string]*StateObject
 
@@ -31,7 +32,7 @@ type StateDB struct {
 // Create a new state from a given trie
 func New(root common.Hash, db common.Database) *StateDB {
 	trie := trie.NewSecure(root[:], db)
-	return &StateDB{db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)}
+	return &StateDB{root: root, db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)}
 }
 
 func (self *StateDB) PrintRoot() {
@@ -185,7 +186,7 @@ func (self *StateDB) DeleteStateObject(stateObject *StateObject) {
 	addr := stateObject.Address()
 	self.trie.Delete(addr[:])
 
-	delete(self.stateObjects, addr.Str())
+	//delete(self.stateObjects, addr.Str())
 }
 
 // Retrieve a state object given my the address. Nil if not found
@@ -323,7 +324,8 @@ func (self *StateDB) Refunds() *big.Int {
 	return self.refund
 }
 
-func (self *StateDB) Update() {
+// SyncIntermediate updates the intermediate state and all mid steps
+func (self *StateDB) SyncIntermediate() {
 	self.refund = new(big.Int)
 
 	for _, stateObject := range self.stateObjects {
@@ -340,6 +342,24 @@ func (self *StateDB) Update() {
 	}
 }
 
+// SyncObjects syncs the changed objects to the trie
+func (self *StateDB) SyncObjects() {
+	self.trie = trie.NewSecure(self.root[:], self.db)
+
+	self.refund = new(big.Int)
+
+	for _, stateObject := range self.stateObjects {
+		if stateObject.remove {
+			self.DeleteStateObject(stateObject)
+		} else {
+			stateObject.Update()
+
+			self.UpdateStateObject(stateObject)
+		}
+		stateObject.dirty = false
+	}
+}
+
 // Debug stuff
 func (self *StateDB) CreateOutputForDiff() {
 	for _, stateObject := range self.stateObjects {

+ 1 - 1
miner/worker.go

@@ -453,7 +453,7 @@ func (self *worker) commitNewWork() {
 	if atomic.LoadInt32(&self.mining) == 1 {
 		// commit state root after all state transitions.
 		core.AccumulateRewards(self.current.state, header, uncles)
-		current.state.Update()
+		current.state.SyncObjects()
 		self.current.state.Sync()
 		header.Root = current.state.Root()
 	}

+ 1 - 1
tests/block_test_util.go

@@ -215,7 +215,7 @@ func (t *BlockTest) InsertPreState(ethereum *eth.Ethereum) (*state.StateDB, erro
 		}
 	}
 	// sync objects to trie
-	statedb.Update()
+	statedb.SyncObjects()
 	// sync trie to disk
 	statedb.Sync()
 

+ 1 - 1
tests/state_test_util.go

@@ -175,7 +175,7 @@ func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, state.
 	if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || state.IsGasLimitErr(err) {
 		statedb.Set(snapshot)
 	}
-	statedb.Update()
+	statedb.SyncObjects()
 
 	return ret, vmenv.state.Logs(), vmenv.Gas, err
 }

+ 9 - 8
trie/fullnode.go

@@ -1,17 +1,16 @@
 package trie
 
-import "fmt"
-
 type FullNode struct {
 	trie  *Trie
 	nodes [17]Node
+	dirty bool
 }
 
 func NewFullNode(t *Trie) *FullNode {
 	return &FullNode{trie: t}
 }
 
-func (self *FullNode) Dirty() bool { return true }
+func (self *FullNode) Dirty() bool { return self.dirty }
 func (self *FullNode) Value() Node {
 	self.nodes[16] = self.trie.trans(self.nodes[16])
 	return self.nodes[16]
@@ -24,9 +23,10 @@ func (self *FullNode) Copy(t *Trie) Node {
 	nnode := NewFullNode(t)
 	for i, node := range self.nodes {
 		if node != nil {
-			nnode.nodes[i] = node.Copy(t)
+			nnode.nodes[i] = node
 		}
 	}
+	nnode.dirty = true
 
 	return nnode
 }
@@ -60,11 +60,8 @@ func (self *FullNode) RlpData() interface{} {
 }
 
 func (self *FullNode) set(k byte, value Node) {
-	if _, ok := value.(*ValueNode); ok && k != 16 {
-		fmt.Println(value, k)
-	}
-
 	self.nodes[int(k)] = value
+	self.dirty = true
 }
 
 func (self *FullNode) branch(i byte) Node {
@@ -75,3 +72,7 @@ func (self *FullNode) branch(i byte) Node {
 	}
 	return nil
 }
+
+func (self *FullNode) setDirty(dirty bool) {
+	self.dirty = dirty
+}

+ 8 - 3
trie/hashnode.go

@@ -3,12 +3,13 @@ package trie
 import "github.com/ethereum/go-ethereum/common"
 
 type HashNode struct {
-	key  []byte
-	trie *Trie
+	key   []byte
+	trie  *Trie
+	dirty bool
 }
 
 func NewHash(key []byte, trie *Trie) *HashNode {
-	return &HashNode{key, trie}
+	return &HashNode{key, trie, false}
 }
 
 func (self *HashNode) RlpData() interface{} {
@@ -19,6 +20,10 @@ func (self *HashNode) Hash() interface{} {
 	return self.key
 }
 
+func (self *HashNode) setDirty(dirty bool) {
+	self.dirty = dirty
+}
+
 // These methods will never be called but we have to satisfy Node interface
 func (self *HashNode) Value() Node       { return nil }
 func (self *HashNode) Dirty() bool       { return true }

+ 1 - 0
trie/node.go

@@ -11,6 +11,7 @@ type Node interface {
 	fstring(string) string
 	Hash() interface{}
 	RlpData() interface{}
+	setDirty(dirty bool)
 }
 
 // Value node

+ 9 - 3
trie/shortnode.go

@@ -6,20 +6,22 @@ type ShortNode struct {
 	trie  *Trie
 	key   []byte
 	value Node
+	dirty bool
 }
 
 func NewShortNode(t *Trie, key []byte, value Node) *ShortNode {
-	return &ShortNode{t, []byte(CompactEncode(key)), value}
+	return &ShortNode{t, []byte(CompactEncode(key)), value, false}
 }
 func (self *ShortNode) Value() Node {
 	self.value = self.trie.trans(self.value)
 
 	return self.value
 }
-func (self *ShortNode) Dirty() bool { return true }
+func (self *ShortNode) Dirty() bool { return self.dirty }
 func (self *ShortNode) Copy(t *Trie) Node {
-	node := &ShortNode{t, nil, self.value.Copy(t)}
+	node := &ShortNode{t, nil, self.value.Copy(t), self.dirty}
 	node.key = common.CopyBytes(self.key)
+	node.dirty = true
 	return node
 }
 
@@ -33,3 +35,7 @@ func (self *ShortNode) Hash() interface{} {
 func (self *ShortNode) Key() []byte {
 	return CompactDecode(string(self.key))
 }
+
+func (self *ShortNode) setDirty(dirty bool) {
+	self.dirty = dirty
+}

+ 29 - 8
trie/trie.go

@@ -117,7 +117,9 @@ func (self *Trie) Update(key, value []byte) Node {
 	k := CompactHexDecode(string(key))
 
 	if len(value) != 0 {
-		self.root = self.insert(self.root, k, &ValueNode{self, value})
+		node := NewValueNode(self, value)
+		node.dirty = true
+		self.root = self.insert(self.root, k, node)
 	} else {
 		self.root = self.delete(self.root, k)
 	}
@@ -157,7 +159,9 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
 	}
 
 	if node == nil {
-		return NewShortNode(self, key, value)
+		node := NewShortNode(self, key, value)
+		node.dirty = true
+		return node
 	}
 
 	switch node := node.(type) {
@@ -165,7 +169,10 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
 		k := node.Key()
 		cnode := node.Value()
 		if bytes.Equal(k, key) {
-			return NewShortNode(self, key, value)
+			node := NewShortNode(self, key, value)
+			node.dirty = true
+			return node
+
 		}
 
 		var n Node
@@ -176,6 +183,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
 			pnode := self.insert(nil, k[matchlength+1:], cnode)
 			nnode := self.insert(nil, key[matchlength+1:], value)
 			fulln := NewFullNode(self)
+			fulln.dirty = true
 			fulln.set(k[matchlength], pnode)
 			fulln.set(key[matchlength], nnode)
 			n = fulln
@@ -184,11 +192,14 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
 			return n
 		}
 
-		return NewShortNode(self, key[:matchlength], n)
+		snode := NewShortNode(self, key[:matchlength], n)
+		snode.dirty = true
+		return snode
 
 	case *FullNode:
 		cpy := node.Copy(self).(*FullNode)
 		cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
+		cpy.dirty = true
 
 		return cpy
 
@@ -242,8 +253,10 @@ func (self *Trie) delete(node Node, key []byte) Node {
 			case *ShortNode:
 				nkey := append(k, child.Key()...)
 				n = NewShortNode(self, nkey, child.Value())
+				n.(*ShortNode).dirty = true
 			case *FullNode:
 				sn := NewShortNode(self, node.Key(), child)
+				sn.dirty = true
 				sn.key = node.key
 				n = sn
 			}
@@ -256,6 +269,7 @@ func (self *Trie) delete(node Node, key []byte) Node {
 	case *FullNode:
 		n := node.Copy(self).(*FullNode)
 		n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
+		n.dirty = true
 
 		pos := -1
 		for i := 0; i < 17; i++ {
@@ -271,6 +285,7 @@ func (self *Trie) delete(node Node, key []byte) Node {
 		var nnode Node
 		if pos == 16 {
 			nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
+			nnode.(*ShortNode).dirty = true
 		} else if pos >= 0 {
 			cnode := n.branch(byte(pos))
 			switch cnode := cnode.(type) {
@@ -278,8 +293,10 @@ func (self *Trie) delete(node Node, key []byte) Node {
 				// Stitch keys
 				k := append([]byte{byte(pos)}, cnode.Key()...)
 				nnode = NewShortNode(self, k, cnode.Value())
+				nnode.(*ShortNode).dirty = true
 			case *FullNode:
 				nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
+				nnode.(*ShortNode).dirty = true
 			}
 		} else {
 			nnode = n
@@ -304,7 +321,7 @@ func (self *Trie) mknode(value *common.Value) Node {
 		if value.Get(0).Len() != 0 {
 			key := CompactDecode(string(value.Get(0).Bytes()))
 			if key[len(key)-1] == 16 {
-				return NewShortNode(self, key, &ValueNode{self, value.Get(1).Bytes()})
+				return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes()))
 			} else {
 				return NewShortNode(self, key, self.mknode(value.Get(1)))
 			}
@@ -318,10 +335,10 @@ func (self *Trie) mknode(value *common.Value) Node {
 			return fnode
 		}
 	case 32:
-		return &HashNode{value.Bytes(), self}
+		return NewHash(value.Bytes(), self)
 	}
 
-	return &ValueNode{self, value.Bytes()}
+	return NewValueNode(self, value.Bytes())
 }
 
 func (self *Trie) trans(node Node) Node {
@@ -338,7 +355,11 @@ func (self *Trie) store(node Node) interface{} {
 	data := common.Encode(node)
 	if len(data) >= 32 {
 		key := crypto.Sha3(data)
-		self.cache.Put(key, data)
+		if node.Dirty() {
+			//fmt.Println("save", node)
+			//fmt.Println()
+			self.cache.Put(key, data)
+		}
 
 		return key
 	}

+ 1 - 1
trie/trie_test.go

@@ -152,7 +152,7 @@ func TestReplication(t *testing.T) {
 	}
 	trie.Commit()
 
-	trie2 := New(trie.roothash, trie.cache.backend)
+	trie2 := New(trie.Root(), trie.cache.backend)
 	if string(trie2.GetString("horse")) != "stallion" {
 		t.Error("expected to have horse => stallion")
 	}

+ 17 - 6
trie/valuenode.go

@@ -3,13 +3,24 @@ package trie
 import "github.com/ethereum/go-ethereum/common"
 
 type ValueNode struct {
-	trie *Trie
-	data []byte
+	trie  *Trie
+	data  []byte
+	dirty bool
 }
 
-func (self *ValueNode) Value() Node          { return self } // Best not to call :-)
-func (self *ValueNode) Val() []byte          { return self.data }
-func (self *ValueNode) Dirty() bool          { return true }
-func (self *ValueNode) Copy(t *Trie) Node    { return &ValueNode{t, common.CopyBytes(self.data)} }
+func NewValueNode(trie *Trie, data []byte) *ValueNode {
+	return &ValueNode{trie, data, false}
+}
+
+func (self *ValueNode) Value() Node { return self } // Best not to call :-)
+func (self *ValueNode) Val() []byte { return self.data }
+func (self *ValueNode) Dirty() bool { return self.dirty }
+func (self *ValueNode) Copy(t *Trie) Node {
+	return &ValueNode{t, common.CopyBytes(self.data), self.dirty}
+}
 func (self *ValueNode) RlpData() interface{} { return self.data }
 func (self *ValueNode) Hash() interface{}    { return self.data }
+
+func (self *ValueNode) setDirty(dirty bool) {
+	self.dirty = dirty
+}