Browse Source

core/vm: EIP-2315, JUMPSUB for the EVM (#20619)

* core/vm: implement EIP 2315, subroutines for the EVM

* core/vm: eip 2315 - lintfix + check jump dest validity + check ret stack size constraints

  logger: markdown-friendly traces, validate jumpdest, more testcase, correct opcodes

* core/vm: update subroutines acc to eip: disallow walk-into

* core/vm/eips: gas cost changes for subroutines

* core/vm: update opcodes for EIP-2315

* core/vm: define RETURNSUB as a 'jumping' operation + review concerns

Co-authored-by: Martin Holst Swende <martin@swende.se>
Greg Colvin 5 years ago
parent
commit
cd57d5cd38

+ 25 - 7
core/vm/contract.go

@@ -83,7 +83,7 @@ func NewContract(caller ContractRef, object ContractRef, value *big.Int, gas uin
 
 func (c *Contract) validJumpdest(dest *big.Int) bool {
 	udest := dest.Uint64()
-	// PC cannot go beyond len(code) and certainly can't be bigger than 63bits.
+	// PC cannot go beyond len(code) and certainly can't be bigger than 63 bits.
 	// Don't bother checking for JUMPDEST in that case.
 	if dest.BitLen() >= 63 || udest >= uint64(len(c.Code)) {
 		return false
@@ -92,16 +92,32 @@ func (c *Contract) validJumpdest(dest *big.Int) bool {
 	if OpCode(c.Code[udest]) != JUMPDEST {
 		return false
 	}
-	// Do we have it locally already?
-	if c.analysis != nil {
-		return c.analysis.codeSegment(udest)
+	return c.isCode(udest)
+}
+
+func (c *Contract) validJumpSubdest(udest uint64) bool {
+	// PC cannot go beyond len(code) and certainly can't be bigger than 63 bits.
+	// Don't bother checking for BEGINSUB in that case.
+	if int64(udest) < 0 || udest >= uint64(len(c.Code)) {
+		return false
+	}
+	// Only BEGINSUBs allowed for destinations
+	if OpCode(c.Code[udest]) != BEGINSUB {
+		return false
 	}
-	// If we have the code hash (but no analysis), we should look into the
-	// parent analysis map and see if the analysis has been made previously
+	return c.isCode(udest)
+}
+
+// isCode returns true if the provided PC location is an actual opcode, as
+// opposed to a data-segment following a PUSHN operation.
+func (c *Contract) isCode(udest uint64) bool {
+	// Do we have a contract hash already?
 	if c.CodeHash != (common.Hash{}) {
+		// Does parent context have the analysis?
 		analysis, exist := c.jumpdests[c.CodeHash]
 		if !exist {
 			// Do the analysis and save in parent context
+			// We do not need to store it in c.analysis
 			analysis = codeBitmap(c.Code)
 			c.jumpdests[c.CodeHash] = analysis
 		}
@@ -113,7 +129,9 @@ func (c *Contract) validJumpdest(dest *big.Int) bool {
 	// in state trie. In that case, we do an analysis, and save it locally, so
 	// we don't have to recalculate it for every JUMP instruction in the execution
 	// However, we don't save it within the parent context
-	c.analysis = codeBitmap(c.Code)
+	if c.analysis == nil {
+		c.analysis = codeBitmap(c.Code)
+	}
 	return c.analysis.codeSegment(udest)
 }
 

+ 33 - 0
core/vm/eips.go

@@ -33,6 +33,8 @@ func EnableEIP(eipNum int, jt *JumpTable) error {
 		enable1884(jt)
 	case 1344:
 		enable1344(jt)
+	case 2315:
+		enable2315(jt)
 	default:
 		return fmt.Errorf("undefined eip %d", eipNum)
 	}
@@ -91,3 +93,34 @@ func enable2200(jt *JumpTable) {
 	jt[SLOAD].constantGas = params.SloadGasEIP2200
 	jt[SSTORE].dynamicGas = gasSStoreEIP2200
 }
+
+// enable2315 applies EIP-2315 (Simple Subroutines)
+// - Adds opcodes that jump to and return from subroutines
+func enable2315(jt *JumpTable) {
+	// New opcode
+	jt[BEGINSUB] = operation{
+		execute:     opBeginSub,
+		constantGas: GasQuickStep,
+		minStack:    minStack(0, 0),
+		maxStack:    maxStack(0, 0),
+		valid:       true,
+	}
+	// New opcode
+	jt[JUMPSUB] = operation{
+		execute:     opJumpSub,
+		constantGas: GasSlowStep,
+		minStack:    minStack(1, 0),
+		maxStack:    maxStack(1, 0),
+		jumps:       true,
+		valid:       true,
+	}
+	// New opcode
+	jt[RETURNSUB] = operation{
+		execute:     opReturnSub,
+		constantGas: GasFastStep,
+		minStack:    minStack(0, 0),
+		maxStack:    maxStack(0, 0),
+		valid:       true,
+		jumps:       true,
+	}
+}

+ 5 - 0
core/vm/errors.go

@@ -23,6 +23,9 @@ import (
 
 // List evm execution errors
 var (
+	// ErrInvalidSubroutineEntry means that a BEGINSUB was reached via iteration,
+	// as opposed to from a JUMPSUB instruction
+	ErrInvalidSubroutineEntry   = errors.New("invalid subroutine entry")
 	ErrOutOfGas                 = errors.New("out of gas")
 	ErrCodeStoreOutOfGas        = errors.New("contract creation code storage out of gas")
 	ErrDepth                    = errors.New("max call depth exceeded")
@@ -34,6 +37,8 @@ var (
 	ErrWriteProtection          = errors.New("write protection")
 	ErrReturnDataOutOfBounds    = errors.New("return data out of bounds")
 	ErrGasUintOverflow          = errors.New("gas uint64 overflow")
+	ErrInvalidRetsub            = errors.New("invalid retsub")
+	ErrReturnStackExceeded      = errors.New("return stack limit reached")
 )
 
 // ErrStackUnderflow wraps an evm error when the items on the stack less

+ 14 - 0
core/vm/gen_structlog.go

@@ -23,6 +23,7 @@ func (s StructLog) MarshalJSON() ([]byte, error) {
 		Memory        hexutil.Bytes               `json:"memory"`
 		MemorySize    int                         `json:"memSize"`
 		Stack         []*math.HexOrDecimal256     `json:"stack"`
+		ReturnStack   []math.HexOrDecimal64       `json:"returnStack"`
 		Storage       map[common.Hash]common.Hash `json:"-"`
 		Depth         int                         `json:"depth"`
 		RefundCounter uint64                      `json:"refund"`
@@ -43,6 +44,12 @@ func (s StructLog) MarshalJSON() ([]byte, error) {
 			enc.Stack[k] = (*math.HexOrDecimal256)(v)
 		}
 	}
+	if s.ReturnStack != nil {
+		enc.ReturnStack = make([]math.HexOrDecimal64, len(s.ReturnStack))
+		for k, v := range s.ReturnStack {
+			enc.ReturnStack[k] = math.HexOrDecimal64(v)
+		}
+	}
 	enc.Storage = s.Storage
 	enc.Depth = s.Depth
 	enc.RefundCounter = s.RefundCounter
@@ -62,6 +69,7 @@ func (s *StructLog) UnmarshalJSON(input []byte) error {
 		Memory        *hexutil.Bytes              `json:"memory"`
 		MemorySize    *int                        `json:"memSize"`
 		Stack         []*math.HexOrDecimal256     `json:"stack"`
+		ReturnStack   []math.HexOrDecimal64       `json:"returnStack"`
 		Storage       map[common.Hash]common.Hash `json:"-"`
 		Depth         *int                        `json:"depth"`
 		RefundCounter *uint64                     `json:"refund"`
@@ -95,6 +103,12 @@ func (s *StructLog) UnmarshalJSON(input []byte) error {
 			s.Stack[k] = (*big.Int)(v)
 		}
 	}
+	if dec.ReturnStack != nil {
+		s.ReturnStack = make([]uint64, len(dec.ReturnStack))
+		for k, v := range dec.ReturnStack {
+			s.ReturnStack[k] = uint64(v)
+		}
+	}
 	if dec.Storage != nil {
 		s.Storage = dec.Storage
 	}

+ 33 - 0
core/vm/instructions.go

@@ -664,6 +664,39 @@ func opJumpdest(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) (
 	return nil, nil
 }
 
+func opBeginSub(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) {
+	return nil, ErrInvalidSubroutineEntry
+}
+
+func opJumpSub(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) {
+	if len(callContext.rstack.data) >= 1023 {
+		return nil, ErrReturnStackExceeded
+	}
+	pos := callContext.stack.pop()
+	if !pos.IsUint64() {
+		return nil, ErrInvalidJump
+	}
+	posU64 := pos.Uint64()
+	if !callContext.contract.validJumpSubdest(posU64) {
+		return nil, ErrInvalidJump
+	}
+	callContext.rstack.push(*pc)
+	*pc = posU64 + 1
+	interpreter.intPool.put(pos)
+	return nil, nil
+}
+
+func opReturnSub(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) {
+	if len(callContext.rstack.data) == 0 {
+		return nil, ErrInvalidRetsub
+	}
+	// Other than the check that the return stack is not empty, there is no
+	// need to validate the pc from 'returns', since we only ever push valid
+	//values onto it via jumpsub.
+	*pc = callContext.rstack.pop() + 1
+	return nil, nil
+}
+
 func opPc(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) {
 	callContext.stack.push(interpreter.intPool.get().SetUint64(*pc))
 	return nil, nil

+ 16 - 15
core/vm/instructions_test.go

@@ -94,6 +94,7 @@ func testTwoOperandOp(t *testing.T, tests []TwoOperandTestcase, opFn executionFu
 	var (
 		env            = NewEVM(Context{}, nil, params.TestChainConfig, Config{})
 		stack          = newstack()
+		rstack         = newReturnStack()
 		pc             = uint64(0)
 		evmInterpreter = env.interpreter.(*EVMInterpreter)
 	)
@@ -109,7 +110,7 @@ func testTwoOperandOp(t *testing.T, tests []TwoOperandTestcase, opFn executionFu
 		expected := new(big.Int).SetBytes(common.Hex2Bytes(test.Expected))
 		stack.push(x)
 		stack.push(y)
-		opFn(&pc, evmInterpreter, &callCtx{nil, stack, nil})
+		opFn(&pc, evmInterpreter, &callCtx{nil, stack, rstack, nil})
 		actual := stack.pop()
 
 		if actual.Cmp(expected) != 0 {
@@ -211,10 +212,10 @@ func TestSAR(t *testing.T) {
 // getResult is a convenience function to generate the expected values
 func getResult(args []*twoOperandParams, opFn executionFunc) []TwoOperandTestcase {
 	var (
-		env         = NewEVM(Context{}, nil, params.TestChainConfig, Config{})
-		stack       = newstack()
-		pc          = uint64(0)
-		interpreter = env.interpreter.(*EVMInterpreter)
+		env           = NewEVM(Context{}, nil, params.TestChainConfig, Config{})
+		stack, rstack = newstack(), newReturnStack()
+		pc            = uint64(0)
+		interpreter   = env.interpreter.(*EVMInterpreter)
 	)
 	interpreter.intPool = poolOfIntPools.get()
 	result := make([]TwoOperandTestcase, len(args))
@@ -223,7 +224,7 @@ func getResult(args []*twoOperandParams, opFn executionFunc) []TwoOperandTestcas
 		y := new(big.Int).SetBytes(common.Hex2Bytes(param.y))
 		stack.push(x)
 		stack.push(y)
-		opFn(&pc, interpreter, &callCtx{nil, stack, nil})
+		opFn(&pc, interpreter, &callCtx{nil, stack, rstack, nil})
 		actual := stack.pop()
 		result[i] = TwoOperandTestcase{param.x, param.y, fmt.Sprintf("%064x", actual)}
 	}
@@ -263,7 +264,7 @@ func TestJsonTestcases(t *testing.T) {
 func opBenchmark(bench *testing.B, op executionFunc, args ...string) {
 	var (
 		env            = NewEVM(Context{}, nil, params.TestChainConfig, Config{})
-		stack          = newstack()
+		stack, rstack  = newstack(), newReturnStack()
 		evmInterpreter = NewEVMInterpreter(env, env.vmConfig)
 	)
 
@@ -281,7 +282,7 @@ func opBenchmark(bench *testing.B, op executionFunc, args ...string) {
 			a := new(big.Int).SetBytes(arg)
 			stack.push(a)
 		}
-		op(&pc, evmInterpreter, &callCtx{nil, stack, nil})
+		op(&pc, evmInterpreter, &callCtx{nil, stack, rstack, nil})
 		stack.pop()
 	}
 	poolOfIntPools.put(evmInterpreter.intPool)
@@ -498,7 +499,7 @@ func BenchmarkOpIsZero(b *testing.B) {
 func TestOpMstore(t *testing.T) {
 	var (
 		env            = NewEVM(Context{}, nil, params.TestChainConfig, Config{})
-		stack          = newstack()
+		stack, rstack  = newstack(), newReturnStack()
 		mem            = NewMemory()
 		evmInterpreter = NewEVMInterpreter(env, env.vmConfig)
 	)
@@ -509,12 +510,12 @@ func TestOpMstore(t *testing.T) {
 	pc := uint64(0)
 	v := "abcdef00000000000000abba000000000deaf000000c0de00100000000133700"
 	stack.pushN(new(big.Int).SetBytes(common.Hex2Bytes(v)), big.NewInt(0))
-	opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil})
+	opMstore(&pc, evmInterpreter, &callCtx{mem, stack, rstack, nil})
 	if got := common.Bytes2Hex(mem.GetCopy(0, 32)); got != v {
 		t.Fatalf("Mstore fail, got %v, expected %v", got, v)
 	}
 	stack.pushN(big.NewInt(0x1), big.NewInt(0))
-	opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil})
+	opMstore(&pc, evmInterpreter, &callCtx{mem, stack, rstack, nil})
 	if common.Bytes2Hex(mem.GetCopy(0, 32)) != "0000000000000000000000000000000000000000000000000000000000000001" {
 		t.Fatalf("Mstore failed to overwrite previous value")
 	}
@@ -524,7 +525,7 @@ func TestOpMstore(t *testing.T) {
 func BenchmarkOpMstore(bench *testing.B) {
 	var (
 		env            = NewEVM(Context{}, nil, params.TestChainConfig, Config{})
-		stack          = newstack()
+		stack, rstack  = newstack(), newReturnStack()
 		mem            = NewMemory()
 		evmInterpreter = NewEVMInterpreter(env, env.vmConfig)
 	)
@@ -539,7 +540,7 @@ func BenchmarkOpMstore(bench *testing.B) {
 	bench.ResetTimer()
 	for i := 0; i < bench.N; i++ {
 		stack.pushN(value, memStart)
-		opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil})
+		opMstore(&pc, evmInterpreter, &callCtx{mem, stack, rstack, nil})
 	}
 	poolOfIntPools.put(evmInterpreter.intPool)
 }
@@ -547,7 +548,7 @@ func BenchmarkOpMstore(bench *testing.B) {
 func BenchmarkOpSHA3(bench *testing.B) {
 	var (
 		env            = NewEVM(Context{}, nil, params.TestChainConfig, Config{})
-		stack          = newstack()
+		stack, rstack  = newstack(), newReturnStack()
 		mem            = NewMemory()
 		evmInterpreter = NewEVMInterpreter(env, env.vmConfig)
 	)
@@ -560,7 +561,7 @@ func BenchmarkOpSHA3(bench *testing.B) {
 	bench.ResetTimer()
 	for i := 0; i < bench.N; i++ {
 		stack.pushN(big.NewInt(32), start)
-		opSha3(&pc, evmInterpreter, &callCtx{mem, stack, nil})
+		opSha3(&pc, evmInterpreter, &callCtx{mem, stack, rstack, nil})
 	}
 	poolOfIntPools.put(evmInterpreter.intPool)
 }

+ 9 - 6
core/vm/interpreter.go

@@ -67,6 +67,7 @@ type Interpreter interface {
 type callCtx struct {
 	memory   *Memory
 	stack    *Stack
+	rstack   *ReturnStack
 	contract *Contract
 }
 
@@ -167,12 +168,14 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) (
 	}
 
 	var (
-		op          OpCode        // current opcode
-		mem         = NewMemory() // bound memory
-		stack       = newstack()  // local stack
+		op          OpCode             // current opcode
+		mem         = NewMemory()      // bound memory
+		stack       = newstack()       // local stack
+		returns     = newReturnStack() // local returns stack
 		callContext = &callCtx{
 			memory:   mem,
 			stack:    stack,
+			rstack:   returns,
 			contract: contract,
 		}
 		// For optimisation reason we're using uint64 as the program counter.
@@ -195,9 +198,9 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) (
 		defer func() {
 			if err != nil {
 				if !logged {
-					in.cfg.Tracer.CaptureState(in.evm, pcCopy, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err)
+					in.cfg.Tracer.CaptureState(in.evm, pcCopy, op, gasCopy, cost, mem, stack, returns, contract, in.evm.depth, err)
 				} else {
-					in.cfg.Tracer.CaptureFault(in.evm, pcCopy, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err)
+					in.cfg.Tracer.CaptureFault(in.evm, pcCopy, op, gasCopy, cost, mem, stack, returns, contract, in.evm.depth, err)
 				}
 			}
 		}()
@@ -279,7 +282,7 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) (
 		}
 
 		if in.cfg.Debug {
-			in.cfg.Tracer.CaptureState(in.evm, pc, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err)
+			in.cfg.Tracer.CaptureState(in.evm, pc, op, gasCopy, cost, mem, stack, returns, contract, in.evm.depth, err)
 			logged = true
 		}
 

+ 95 - 5
core/vm/logger.go

@@ -22,6 +22,7 @@ import (
 	"fmt"
 	"io"
 	"math/big"
+	"strings"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -66,6 +67,7 @@ type StructLog struct {
 	Memory        []byte                      `json:"memory"`
 	MemorySize    int                         `json:"memSize"`
 	Stack         []*big.Int                  `json:"stack"`
+	ReturnStack   []uint64                    `json:"returnStack"`
 	Storage       map[common.Hash]common.Hash `json:"-"`
 	Depth         int                         `json:"depth"`
 	RefundCounter uint64                      `json:"refund"`
@@ -75,6 +77,7 @@ type StructLog struct {
 // overrides for gencodec
 type structLogMarshaling struct {
 	Stack       []*math.HexOrDecimal256
+	ReturnStack []math.HexOrDecimal64
 	Gas         math.HexOrDecimal64
 	GasCost     math.HexOrDecimal64
 	Memory      hexutil.Bytes
@@ -102,8 +105,8 @@ func (s *StructLog) ErrorString() string {
 // if you need to retain them beyond the current call.
 type Tracer interface {
 	CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error
-	CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error
-	CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error
+	CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error
+	CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error
 	CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error
 }
 
@@ -140,7 +143,7 @@ func (l *StructLogger) CaptureStart(from common.Address, to common.Address, crea
 // CaptureState logs a new structured log message and pushes it out to the environment
 //
 // CaptureState also tracks SSTORE ops to track dirty values.
-func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error {
+func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error {
 	// check if already accumulated the specified number of logs
 	if l.cfg.Limit != 0 && l.cfg.Limit <= len(l.logs) {
 		return errTraceLimitReached
@@ -180,8 +183,13 @@ func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost ui
 	if !l.cfg.DisableStorage {
 		storage = l.changedValues[contract.Address()].Copy()
 	}
+	var rstack []uint64
+	if !l.cfg.DisableStack && rStack != nil {
+		rstck := make([]uint64, len(rStack.data))
+		copy(rstck, rStack.data)
+	}
 	// create a new snapshot of the EVM.
-	log := StructLog{pc, op, gas, cost, mem, memory.Len(), stck, storage, depth, env.StateDB.GetRefund(), err}
+	log := StructLog{pc, op, gas, cost, mem, memory.Len(), stck, rstack, storage, depth, env.StateDB.GetRefund(), err}
 
 	l.logs = append(l.logs, log)
 	return nil
@@ -189,7 +197,7 @@ func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost ui
 
 // CaptureFault implements the Tracer interface to trace an execution fault
 // while running an opcode.
-func (l *StructLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error {
+func (l *StructLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error {
 	return nil
 }
 
@@ -230,6 +238,12 @@ func WriteTrace(writer io.Writer, logs []StructLog) {
 				fmt.Fprintf(writer, "%08d  %x\n", len(log.Stack)-i-1, math.PaddedBigBytes(log.Stack[i], 32))
 			}
 		}
+		if len(log.ReturnStack) > 0 {
+			fmt.Fprintln(writer, "ReturnStack:")
+			for i := len(log.Stack) - 1; i >= 0; i-- {
+				fmt.Fprintf(writer, "%08d  0x%x (%d)\n", len(log.Stack)-i-1, log.ReturnStack[i], log.ReturnStack[i])
+			}
+		}
 		if len(log.Memory) > 0 {
 			fmt.Fprintln(writer, "Memory:")
 			fmt.Fprint(writer, hex.Dump(log.Memory))
@@ -257,3 +271,79 @@ func WriteLogs(writer io.Writer, logs []*types.Log) {
 		fmt.Fprintln(writer)
 	}
 }
+
+type mdLogger struct {
+	out io.Writer
+	cfg *LogConfig
+}
+
+// NewMarkdownLogger creates a logger which outputs information in a format adapted
+// for human readability, and is also a valid markdown table
+func NewMarkdownLogger(cfg *LogConfig, writer io.Writer) *mdLogger {
+	l := &mdLogger{writer, cfg}
+	if l.cfg == nil {
+		l.cfg = &LogConfig{}
+	}
+	return l
+}
+
+func (t *mdLogger) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error {
+	if !create {
+		fmt.Fprintf(t.out, "From: `%v`\nTo: `%v`\nData: `0x%x`\nGas: `%d`\nValue `%v` wei\n",
+			from.String(), to.String(),
+			input, gas, value)
+	} else {
+		fmt.Fprintf(t.out, "From: `%v`\nCreate at: `%v`\nData: `0x%x`\nGas: `%d`\nValue `%v` wei\n",
+			from.String(), to.String(),
+			input, gas, value)
+	}
+
+	fmt.Fprintf(t.out, `
+|  Pc   |      Op     | Cost |   Stack   |   RStack  |
+|-------|-------------|------|-----------|-----------|
+`)
+	return nil
+}
+
+func (t *mdLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error {
+	fmt.Fprintf(t.out, "| %4d  | %10v  |  %3d |", pc, op, cost)
+
+	if !t.cfg.DisableStack { // format stack
+		var a []string
+		for _, elem := range stack.data {
+			a = append(a, fmt.Sprintf("%d", elem))
+		}
+		b := fmt.Sprintf("[%v]", strings.Join(a, ","))
+		fmt.Fprintf(t.out, "%10v |", b)
+	}
+	if !t.cfg.DisableStack { // format return stack
+		var a []string
+		for _, elem := range rStack.data {
+			a = append(a, fmt.Sprintf("%2d", elem))
+		}
+		b := fmt.Sprintf("[%v]", strings.Join(a, ","))
+		fmt.Fprintf(t.out, "%10v |", b)
+	}
+	fmt.Fprintln(t.out, "")
+	if err != nil {
+		fmt.Fprintf(t.out, "Error: %v\n", err)
+	}
+	return nil
+}
+
+func (t *mdLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error {
+
+	fmt.Fprintf(t.out, "\nError: at pc=%d, op=%v: %v\n", pc, op, err)
+
+	return nil
+}
+
+func (t *mdLogger) CaptureEnd(output []byte, gasUsed uint64, tm time.Duration, err error) error {
+	fmt.Fprintf(t.out, `
+Output: 0x%x
+Consumed gas: %d
+Error: %v
+`,
+		output, gasUsed, err)
+	return nil
+}

+ 3 - 2
core/vm/logger_json.go

@@ -46,7 +46,7 @@ func (l *JSONLogger) CaptureStart(from common.Address, to common.Address, create
 }
 
 // CaptureState outputs state information on the logger.
-func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error {
+func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error {
 	log := StructLog{
 		Pc:            pc,
 		Op:            op,
@@ -63,12 +63,13 @@ func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint
 	}
 	if !l.cfg.DisableStack {
 		log.Stack = stack.Data()
+		log.ReturnStack = rStack.data
 	}
 	return l.encoder.Encode(log)
 }
 
 // CaptureFault outputs state information on the logger.
-func (l *JSONLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error {
+func (l *JSONLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, contract *Contract, depth int, err error) error {
 	return nil
 }
 

+ 2 - 1
core/vm/logger_test.go

@@ -54,12 +54,13 @@ func TestStoreCapture(t *testing.T) {
 		logger   = NewStructLogger(nil)
 		mem      = NewMemory()
 		stack    = newstack()
+		rstack   = newReturnStack()
 		contract = NewContract(&dummyContractRef{}, &dummyContractRef{}, new(big.Int), 0)
 	)
 	stack.push(big.NewInt(1))
 	stack.push(big.NewInt(0))
 	var index common.Hash
-	logger.CaptureState(env, 0, SSTORE, 0, 0, mem, stack, contract, 0, nil)
+	logger.CaptureState(env, 0, SSTORE, 0, 0, mem, stack, rstack, contract, 0, nil)
 	if len(logger.changedValues[contract.Address()]) == 0 {
 		t.Fatalf("expected exactly 1 changed value on address %x, got %d", contract.Address(), len(logger.changedValues[contract.Address()]))
 	}

+ 22 - 12
core/vm/opcodes.go

@@ -107,18 +107,21 @@ const (
 
 // 0x50 range - 'storage' and execution.
 const (
-	POP OpCode = 0x50 + iota
-	MLOAD
-	MSTORE
-	MSTORE8
-	SLOAD
-	SSTORE
-	JUMP
-	JUMPI
-	PC
-	MSIZE
-	GAS
-	JUMPDEST
+	POP       OpCode = 0x50
+	MLOAD     OpCode = 0x51
+	MSTORE    OpCode = 0x52
+	MSTORE8   OpCode = 0x53
+	SLOAD     OpCode = 0x54
+	SSTORE    OpCode = 0x55
+	JUMP      OpCode = 0x56
+	JUMPI     OpCode = 0x57
+	PC        OpCode = 0x58
+	MSIZE     OpCode = 0x59
+	GAS       OpCode = 0x5a
+	JUMPDEST  OpCode = 0x5b
+	BEGINSUB  OpCode = 0x5c
+	RETURNSUB OpCode = 0x5d
+	JUMPSUB   OpCode = 0x5e
 )
 
 // 0x60 range.
@@ -297,6 +300,10 @@ var opCodeToString = map[OpCode]string{
 	GAS:      "GAS",
 	JUMPDEST: "JUMPDEST",
 
+	BEGINSUB:  "BEGINSUB",
+	JUMPSUB:   "JUMPSUB",
+	RETURNSUB: "RETURNSUB",
+
 	// 0x60 range - push.
 	PUSH1:  "PUSH1",
 	PUSH2:  "PUSH2",
@@ -461,6 +468,9 @@ var stringToOp = map[string]OpCode{
 	"MSIZE":          MSIZE,
 	"GAS":            GAS,
 	"JUMPDEST":       JUMPDEST,
+	"BEGINSUB":       BEGINSUB,
+	"RETURNSUB":      RETURNSUB,
+	"JUMPSUB":        JUMPSUB,
 	"PUSH1":          PUSH1,
 	"PUSH2":          PUSH2,
 	"PUSH3":          PUSH3,

+ 249 - 0
core/vm/runtime/runtime_test.go

@@ -17,14 +17,18 @@
 package runtime
 
 import (
+	"fmt"
 	"math/big"
+	"os"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/ethereum/go-ethereum/accounts/abi"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/consensus"
 	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/asm"
 	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/core/state"
 	"github.com/ethereum/go-ethereum/core/types"
@@ -344,3 +348,248 @@ func BenchmarkSimpleLoop(b *testing.B) {
 		Execute(code, nil, nil)
 	}
 }
+
+type stepCounter struct {
+	inner *vm.JSONLogger
+	steps int
+}
+
+func (s *stepCounter) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error {
+	return nil
+}
+
+func (s *stepCounter) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, rStack *vm.ReturnStack, contract *vm.Contract, depth int, err error) error {
+	s.steps++
+	// Enable this for more output
+	//s.inner.CaptureState(env, pc, op, gas, cost, memory, stack, rStack, contract, depth, err)
+	return nil
+}
+
+func (s *stepCounter) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, rStack *vm.ReturnStack, contract *vm.Contract, depth int, err error) error {
+	return nil
+}
+
+func (s *stepCounter) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error {
+	return nil
+}
+
+func TestJumpSub1024Limit(t *testing.T) {
+	state, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)
+	address := common.HexToAddress("0x0a")
+	// Code is
+	// 0 beginsub
+	// 1 push 0
+	// 3 jumpsub
+	//
+	// The code recursively calls itself. It should error when the returns-stack
+	// grows above 1023
+	state.SetCode(address, []byte{
+		byte(vm.PUSH1), 3,
+		byte(vm.JUMPSUB),
+		byte(vm.BEGINSUB),
+		byte(vm.PUSH1), 3,
+		byte(vm.JUMPSUB),
+	})
+	tracer := stepCounter{inner: vm.NewJSONLogger(nil, os.Stdout)}
+	// Enable 2315
+	_, _, err := Call(address, nil, &Config{State: state,
+		GasLimit:    20000,
+		ChainConfig: params.AllEthashProtocolChanges,
+		EVMConfig: vm.Config{
+			ExtraEips: []int{2315},
+			Debug:     true,
+			//Tracer:    vm.NewJSONLogger(nil, os.Stdout),
+			Tracer: &tracer,
+		}})
+	exp := "return stack limit reached"
+	if err.Error() != exp {
+		t.Fatalf("expected %v, got %v", exp, err)
+	}
+	if exp, got := 2048, tracer.steps; exp != got {
+		t.Fatalf("expected %d steps, got %d", exp, got)
+	}
+}
+
+func TestReturnSubShallow(t *testing.T) {
+	state, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)
+	address := common.HexToAddress("0x0a")
+	// The code does returnsub without having anything on the returnstack.
+	// It should not panic, but just fail after one step
+	state.SetCode(address, []byte{
+		byte(vm.PUSH1), 5,
+		byte(vm.JUMPSUB),
+		byte(vm.RETURNSUB),
+		byte(vm.PC),
+		byte(vm.BEGINSUB),
+		byte(vm.RETURNSUB),
+		byte(vm.PC),
+	})
+	tracer := stepCounter{}
+
+	// Enable 2315
+	_, _, err := Call(address, nil, &Config{State: state,
+		GasLimit:    10000,
+		ChainConfig: params.AllEthashProtocolChanges,
+		EVMConfig: vm.Config{
+			ExtraEips: []int{2315},
+			Debug:     true,
+			Tracer:    &tracer,
+		}})
+
+	exp := "invalid retsub"
+	if err.Error() != exp {
+		t.Fatalf("expected %v, got %v", exp, err)
+	}
+	if exp, got := 4, tracer.steps; exp != got {
+		t.Fatalf("expected %d steps, got %d", exp, got)
+	}
+}
+
+// disabled -- only used for generating markdown
+func DisabledTestReturnCases(t *testing.T) {
+	cfg := &Config{
+		EVMConfig: vm.Config{
+			Debug:     true,
+			Tracer:    vm.NewMarkdownLogger(nil, os.Stdout),
+			ExtraEips: []int{2315},
+		},
+	}
+	// This should fail at first opcode
+	Execute([]byte{
+		byte(vm.RETURNSUB),
+		byte(vm.PC),
+		byte(vm.PC),
+	}, nil, cfg)
+
+	// Should also fail
+	Execute([]byte{
+		byte(vm.PUSH1), 5,
+		byte(vm.JUMPSUB),
+		byte(vm.RETURNSUB),
+		byte(vm.PC),
+		byte(vm.BEGINSUB),
+		byte(vm.RETURNSUB),
+		byte(vm.PC),
+	}, nil, cfg)
+
+	// This should complete
+	Execute([]byte{
+		byte(vm.PUSH1), 0x4,
+		byte(vm.JUMPSUB),
+		byte(vm.STOP),
+		byte(vm.BEGINSUB),
+		byte(vm.PUSH1), 0x9,
+		byte(vm.JUMPSUB),
+		byte(vm.RETURNSUB),
+		byte(vm.BEGINSUB),
+		byte(vm.RETURNSUB),
+	}, nil, cfg)
+}
+
+// DisabledTestEipExampleCases contains various testcases that are used for the
+// EIP examples
+// This test is disabled, as it's only used for generating markdown
+func DisabledTestEipExampleCases(t *testing.T) {
+	cfg := &Config{
+		EVMConfig: vm.Config{
+			Debug:     true,
+			Tracer:    vm.NewMarkdownLogger(nil, os.Stdout),
+			ExtraEips: []int{2315},
+		},
+	}
+	prettyPrint := func(comment string, code []byte) {
+		instrs := make([]string, 0)
+		it := asm.NewInstructionIterator(code)
+		for it.Next() {
+			if it.Arg() != nil && 0 < len(it.Arg()) {
+				instrs = append(instrs, fmt.Sprintf("%v 0x%x", it.Op(), it.Arg()))
+			} else {
+				instrs = append(instrs, fmt.Sprintf("%v", it.Op()))
+			}
+		}
+		ops := strings.Join(instrs, ", ")
+
+		fmt.Printf("%v\nBytecode: `0x%x` (`%v`)\n",
+			comment,
+			code, ops)
+		Execute(code, nil, cfg)
+	}
+
+	{ // First eip testcase
+		code := []byte{
+			byte(vm.PUSH1), 4,
+			byte(vm.JUMPSUB),
+			byte(vm.STOP),
+			byte(vm.BEGINSUB),
+			byte(vm.RETURNSUB),
+		}
+		prettyPrint("This should jump into a subroutine, back out and stop.", code)
+	}
+
+	{
+		code := []byte{
+			byte(vm.PUSH9), 0x00, 0x00, 0x00, 0x00, 0x0, 0x00, 0x00, 0x00, (4 + 8),
+			byte(vm.JUMPSUB),
+			byte(vm.STOP),
+			byte(vm.BEGINSUB),
+			byte(vm.PUSH1), 8 + 9,
+			byte(vm.JUMPSUB),
+			byte(vm.RETURNSUB),
+			byte(vm.BEGINSUB),
+			byte(vm.RETURNSUB),
+		}
+		prettyPrint("This should execute fine, going into one two depths of subroutines", code)
+	}
+	// TODO(@holiman) move this test into an actual test, which not only prints
+	// out the trace.
+	{
+		code := []byte{
+			byte(vm.PUSH9), 0x01, 0x00, 0x00, 0x00, 0x0, 0x00, 0x00, 0x00, (4 + 8),
+			byte(vm.JUMPSUB),
+			byte(vm.STOP),
+			byte(vm.BEGINSUB),
+			byte(vm.PUSH1), 8 + 9,
+			byte(vm.JUMPSUB),
+			byte(vm.RETURNSUB),
+			byte(vm.BEGINSUB),
+			byte(vm.RETURNSUB),
+		}
+		prettyPrint("This should fail, since the given location is outside of the "+
+			"code-range. The code is the same as previous example, except that the "+
+			"pushed location is `0x01000000000000000c` instead of `0x0c`.", code)
+	}
+	{
+		// This should fail at first opcode
+		code := []byte{
+			byte(vm.RETURNSUB),
+			byte(vm.PC),
+			byte(vm.PC),
+		}
+		prettyPrint("This should fail at first opcode, due to shallow `return_stack`", code)
+
+	}
+	{
+		code := []byte{
+			byte(vm.PUSH1), 5, // Jump past the subroutine
+			byte(vm.JUMP),
+			byte(vm.BEGINSUB),
+			byte(vm.RETURNSUB),
+			byte(vm.JUMPDEST),
+			byte(vm.PUSH1), 3, // Now invoke the subroutine
+			byte(vm.JUMPSUB),
+		}
+		prettyPrint("In this example. the JUMPSUB is on the last byte of code. When the "+
+			"subroutine returns, it should hit the 'virtual stop' _after_ the bytecode, "+
+			"and not exit with error", code)
+	}
+
+	{
+		code := []byte{
+			byte(vm.BEGINSUB),
+			byte(vm.RETURNSUB),
+			byte(vm.STOP),
+		}
+		prettyPrint("In this example, the code 'walks' into a subroutine, which is not "+
+			"allowed, and causes an error", code)
+	}
+}

+ 19 - 0
core/vm/stack.go

@@ -86,3 +86,22 @@ func (st *Stack) Print() {
 	}
 	fmt.Println("#############")
 }
+
+// ReturnStack is an object for basic return stack operations.
+type ReturnStack struct {
+	data []uint64
+}
+
+func newReturnStack() *ReturnStack {
+	return &ReturnStack{data: make([]uint64, 0, 1024)}
+}
+
+func (st *ReturnStack) push(d uint64) {
+	st.data = append(st.data, d)
+}
+
+func (st *ReturnStack) pop() (ret uint64) {
+	ret = st.data[len(st.data)-1]
+	st.data = st.data[:len(st.data)-1]
+	return
+}

+ 2 - 2
eth/tracers/tracer.go

@@ -541,7 +541,7 @@ func (jst *Tracer) CaptureStart(from common.Address, to common.Address, create b
 }
 
 // CaptureState implements the Tracer interface to trace a single step of VM execution.
-func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error {
+func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, rStack *vm.ReturnStack, contract *vm.Contract, depth int, err error) error {
 	if jst.err == nil {
 		// Initialize the context if it wasn't done yet
 		if !jst.inited {
@@ -580,7 +580,7 @@ func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost
 
 // CaptureFault implements the Tracer interface to trace an execution fault
 // while running an opcode.
-func (jst *Tracer) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error {
+func (jst *Tracer) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, rStack *vm.ReturnStack, contract *vm.Contract, depth int, err error) error {
 	if jst.err == nil {
 		// Apart from the error, everything matches the previous invocation
 		jst.errorValue = new(string)

+ 2 - 2
eth/tracers/tracer_test.go

@@ -169,10 +169,10 @@ func TestHaltBetweenSteps(t *testing.T) {
 	env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer})
 	contract := vm.NewContract(&account{}, &account{}, big.NewInt(0), 0)
 
-	tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, contract, 0, nil)
+	tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, nil, contract, 0, nil)
 	timeout := errors.New("stahp")
 	tracer.Stop(timeout)
-	tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, contract, 0, nil)
+	tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, nil, contract, 0, nil)
 
 	if _, err := tracer.GetResult(); err.Error() != timeout.Error() {
 		t.Errorf("Expected timeout error, got %v", err)