Browse Source

add chain id into sign bytes to avoid replay attack (#18)

zjubfd 5 years ago
parent
commit
f4816ee8b7
3 changed files with 20 additions and 18 deletions
  1. 13 12
      consensus/parlia/parlia.go
  2. 3 2
      consensus/parlia/snapshot.go
  3. 4 4
      signer/core/signed_data.go

+ 13 - 12
consensus/parlia/parlia.go

@@ -156,7 +156,7 @@ func isToSystemContract(to common.Address) bool {
 }
 
 // ecrecover extracts the Ethereum account address from a signed header.
-func ecrecover(header *types.Header, sigCache *lru.ARCCache) (common.Address, error) {
+func ecrecover(header *types.Header, sigCache *lru.ARCCache, chainId *big.Int) (common.Address, error) {
 	// If the signature's already cached, return that
 	hash := header.Hash()
 	if address, known := sigCache.Get(hash); known {
@@ -169,7 +169,7 @@ func ecrecover(header *types.Header, sigCache *lru.ARCCache) (common.Address, er
 	signature := header.Extra[len(header.Extra)-extraSeal:]
 
 	// Recover the public key and the Ethereum address
-	pubkey, err := crypto.Ecrecover(SealHash(header).Bytes(), signature)
+	pubkey, err := crypto.Ecrecover(SealHash(header, chainId).Bytes(), signature)
 	if err != nil {
 		return common.Address{}, err
 	}
@@ -187,9 +187,9 @@ func ecrecover(header *types.Header, sigCache *lru.ARCCache) (common.Address, er
 // Note, the method requires the extra data to be at least 65 bytes, otherwise it
 // panics. This is done to avoid accidentally using both forms (signature present
 // or not), which could be abused to produce different hashes for the same header.
-func ParliaRLP(header *types.Header) []byte {
+func ParliaRLP(header *types.Header, chainId *big.Int) []byte {
 	b := new(bytes.Buffer)
-	encodeSigHeader(b, header)
+	encodeSigHeader(b, header, chainId)
 	return b.Bytes()
 }
 
@@ -498,7 +498,7 @@ func (p *Parlia) snapshot(chain consensus.ChainReader, number uint64, hash commo
 		headers[i], headers[len(headers)-1-i] = headers[len(headers)-1-i], headers[i]
 	}
 
-	snap, err := snap.apply(headers, chain, parents)
+	snap, err := snap.apply(headers, chain, parents, p.chainConfig.ChainID)
 	if err != nil {
 		return nil, err
 	}
@@ -546,7 +546,7 @@ func (p *Parlia) verifySeal(chain consensus.ChainReader, header *types.Header, p
 	}
 
 	// Resolve the authorization key and check against validators
-	signer, err := ecrecover(header, p.signatures)
+	signer, err := ecrecover(header, p.signatures, p.chainConfig.ChainID)
 	if err != nil {
 		return err
 	}
@@ -821,7 +821,7 @@ func (p *Parlia) Seal(chain consensus.ChainReader, block *types.Block, results c
 	log.Info("Sealing block with", "number", number, "delay", delay, "headerDifficulty", header.Difficulty, "val", val.Hex())
 
 	// Sign all the things!
-	sig, err := signFn(accounts.Account{Address: val}, accounts.MimetypeParlia, ParliaRLP(header))
+	sig, err := signFn(accounts.Account{Address: val}, accounts.MimetypeParlia, ParliaRLP(header, p.chainConfig.ChainID))
 	if err != nil {
 		return err
 	}
@@ -839,7 +839,7 @@ func (p *Parlia) Seal(chain consensus.ChainReader, block *types.Block, results c
 		select {
 		case results <- block.WithSeal(header):
 		default:
-			log.Warn("Sealing result is not read by miner", "sealhash", SealHash(header))
+			log.Warn("Sealing result is not read by miner", "sealhash", SealHash(header, p.chainConfig.ChainID))
 		}
 	}()
 
@@ -869,7 +869,7 @@ func CalcDifficulty(snap *Snapshot, signer common.Address) *big.Int {
 
 // SealHash returns the hash of a block prior to it being sealed.
 func (p *Parlia) SealHash(header *types.Header) common.Hash {
-	return SealHash(header)
+	return SealHash(header, p.chainConfig.ChainID)
 }
 
 // APIs implements consensus.Engine, returning the user facing RPC API to query snapshot.
@@ -1109,15 +1109,16 @@ func (p *Parlia) applyTransaction(
 
 // ===========================     utility function        ==========================
 // SealHash returns the hash of a block prior to it being sealed.
-func SealHash(header *types.Header) (hash common.Hash) {
+func SealHash(header *types.Header, chainId *big.Int) (hash common.Hash) {
 	hasher := sha3.NewLegacyKeccak256()
-	encodeSigHeader(hasher, header)
+	encodeSigHeader(hasher, header, chainId)
 	hasher.Sum(hash[:0])
 	return hash
 }
 
-func encodeSigHeader(w io.Writer, header *types.Header) {
+func encodeSigHeader(w io.Writer, header *types.Header, chainId *big.Int) {
 	err := rlp.Encode(w, []interface{}{
+		chainId,
 		header.ParentHash,
 		header.UncleHash,
 		header.Coinbase,

+ 3 - 2
consensus/parlia/snapshot.go

@@ -20,6 +20,7 @@ import (
 	"bytes"
 	"encoding/json"
 	"errors"
+	"math/big"
 	"sort"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -123,7 +124,7 @@ func (s *Snapshot) copy() *Snapshot {
 	return cpy
 }
 
-func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainReader, parents []*types.Header) (*Snapshot, error) {
+func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainReader, parents []*types.Header, chainId *big.Int) (*Snapshot, error) {
 	// Allow passing in no headers for cleaner code
 	if len(headers) == 0 {
 		return s, nil
@@ -153,7 +154,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainReader, p
 			delete(snap.Recents, number-limit)
 		}
 		// Resolve the authorization key and check against signers
-		validator, err := ecrecover(header, s.sigCache)
+		validator, err := ecrecover(header, s.sigCache, chainId)
 		if err != nil {
 			return nil, err
 		}

+ 4 - 4
signer/core/signed_data.go

@@ -285,7 +285,7 @@ func (api *SignerAPI) determineSignatureFormat(ctx context.Context, contentType
 			header.Extra = newExtra
 		}
 		// Get back the rlp data, encoded by us
-		sighash, parliaRlp, err := parliaHeaderHashAndRlp(header)
+		sighash, parliaRlp, err := parliaHeaderHashAndRlp(header, api.chainID)
 		if err != nil {
 			return nil, useEthereumV, err
 		}
@@ -351,13 +351,13 @@ func cliqueHeaderHashAndRlp(header *types.Header) (hash, rlp []byte, err error)
 	return hash, rlp, err
 }
 
-func parliaHeaderHashAndRlp(header *types.Header) (hash, rlp []byte, err error) {
+func parliaHeaderHashAndRlp(header *types.Header, chainId *big.Int) (hash, rlp []byte, err error) {
 	if len(header.Extra) < 65 {
 		err = fmt.Errorf("clique header extradata too short, %d < 65", len(header.Extra))
 		return
 	}
-	rlp = parlia.ParliaRLP(header)
-	hash = parlia.SealHash(header).Bytes()
+	rlp = parlia.ParliaRLP(header, chainId)
+	hash = parlia.SealHash(header, chainId).Bytes()
 	return hash, rlp, err
 }