소스 검색

cmd/utils, eth: minor polishes on whitelist code

Péter Szilágyi 7 년 전
부모
커밋
31b3334922
2개의 변경된 파일29개의 추가작업 그리고 41개의 파일을 삭제
  1. 18 23
      cmd/utils/flags.go
  2. 11 18
      eth/handler.go

+ 18 - 23
cmd/utils/flags.go

@@ -1077,30 +1077,25 @@ func setEthash(ctx *cli.Context, cfg *eth.Config) {
 }
 
 func setWhitelist(ctx *cli.Context, cfg *eth.Config) {
-	if ctx.GlobalIsSet(WhitelistFlag.Name) {
-		entries := strings.Split(ctx.String(WhitelistFlag.Name), ",")
-		whitelist := make(map[uint64]common.Hash)
-		for _, entry := range entries {
-			split := strings.SplitN(entry, "=", 2)
-			if len(split) != 2 {
-				Fatalf("invalid whitelist entry: %s", entry)
-			}
-
-			bn, err := strconv.ParseUint(split[0], 0, 64)
-			if err != nil {
-				Fatalf("Invalid whitelist block number %s: %v", split[0], err)
-			}
-
-			hash := common.Hash{}
-			err = hash.UnmarshalText([]byte(split[1]))
-			if err != nil {
-				Fatalf("Invalid whitelist hash %s: %v", split[1], err)
-			}
-
-			whitelist[bn] = hash
+	whitelist := ctx.GlobalString(WhitelistFlag.Name)
+	if whitelist == "" {
+		return
+	}
+	cfg.Whitelist = make(map[uint64]common.Hash)
+	for _, entry := range strings.Split(whitelist, ",") {
+		parts := strings.Split(entry, "=")
+		if len(parts) != 2 {
+			Fatalf("Invalid whitelist entry: %s", entry)
 		}
-
-		cfg.Whitelist = whitelist
+		number, err := strconv.ParseUint(parts[0], 0, 64)
+		if err != nil {
+			Fatalf("Invalid whitelist block number %s: %v", parts[0], err)
+		}
+		var hash common.Hash
+		if err = hash.UnmarshalText([]byte(parts[1])); err != nil {
+			Fatalf("Invalid whitelist hash %s: %v", parts[1], err)
+		}
+		cfg.Whitelist[number] = hash
 	}
 }
 

+ 11 - 18
eth/handler.go

@@ -17,7 +17,6 @@
 package eth
 
 import (
-	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -311,17 +310,13 @@ func (pm *ProtocolManager) handle(p *peer) error {
 			}
 		}()
 	}
-
 	// If we have any explicit whitelist block hashes, request them
-	for bn := range pm.whitelist {
-		p.Log().Debug("Requesting whitelist block", "number", bn)
-		if err := p.RequestHeadersByNumber(bn, 1, 0, false); err != nil {
-			p.Log().Error("whitelist request failed", "err", err, "number", bn, "peer", p.id)
+	for number := range pm.whitelist {
+		if err := p.RequestHeadersByNumber(number, 1, 0, false); err != nil {
 			return err
 		}
 	}
-
-	// main loop. handle incoming messages.
+	// Handle incoming messages until the connection is torn down
 	for {
 		if err := pm.handleMsg(p); err != nil {
 			p.Log().Debug("Ethereum message handling failed", "err", err)
@@ -466,16 +461,6 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		// Filter out any explicitly requested headers, deliver the rest to the downloader
 		filter := len(headers) == 1
 		if filter {
-			// Check for any responses not matching our whitelist
-			if expected, ok := pm.whitelist[headers[0].Number.Uint64()]; ok {
-				actual := headers[0].Hash()
-				if !bytes.Equal(expected.Bytes(), actual.Bytes()) {
-					p.Log().Info("Dropping peer with non-matching whitelist block", "number", headers[0].Number.Uint64(), "hash", actual, "expected", expected)
-					return errors.New("whitelist block mismatch")
-				}
-				p.Log().Debug("Whitelist block verified", "number", headers[0].Number.Uint64(), "hash", expected)
-			}
-
 			// If it's a potential DAO fork check, validate against the rules
 			if p.forkDrop != nil && pm.chainconfig.DAOForkBlock.Cmp(headers[0].Number) == 0 {
 				// Disable the fork drop timer
@@ -490,6 +475,14 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 				p.Log().Debug("Verified to be on the same side of the DAO fork")
 				return nil
 			}
+			// Otherwise if it's a whitelisted block, validate against the set
+			if want, ok := pm.whitelist[headers[0].Number.Uint64()]; ok {
+				if hash := headers[0].Hash(); want != hash {
+					p.Log().Info("Whitelist mismatch, dropping peer", "number", headers[0].Number.Uint64(), "hash", hash, "want", want)
+					return errors.New("whitelist block mismatch")
+				}
+				p.Log().Debug("Whitelist block verified", "number", headers[0].Number.Uint64(), "hash", want)
+			}
 			// Irrelevant of the fork checks, send the header to the fetcher just in case
 			headers = pm.fetcher.FilterHeaders(p.id, headers, time.Now())
 		}