瀏覽代碼

eth, eth/downloader: prevent hash repeater attack

Péter Szilágyi 10 年之前
父節點
當前提交
cd2fb09051
共有 4 個文件被更改,包括 50 次插入15 次删除
  1. 6 4
      eth/downloader/downloader.go
  2. 29 0
      eth/downloader/downloader_test.go
  3. 11 9
      eth/downloader/queue.go
  4. 4 2
      eth/sync.go

+ 6 - 4
eth/downloader/downloader.go

@@ -27,7 +27,7 @@ var (
 	errLowTd               = errors.New("peer's TD is too low")
 	ErrBusy                = errors.New("busy")
 	errUnknownPeer         = errors.New("peer's unknown or unhealthy")
-	errBadPeer             = errors.New("action from bad peer ignored")
+	ErrBadPeer             = errors.New("action from bad peer ignored")
 	errNoPeers             = errors.New("no peers to keep download active")
 	ErrPendingQueue        = errors.New("pending items in queue")
 	ErrTimeout             = errors.New("timeout")
@@ -266,9 +266,11 @@ out:
 					break
 				}
 			}
-			d.queue.Insert(hashPack.hashes)
-
-			if !done {
+			// Insert all the new hashes, but only continue if got something useful
+			inserts := d.queue.Insert(hashPack.hashes)
+			if inserts == 0 && !done {
+				return ErrBadPeer
+			} else if !done {
 				activePeer.getHashes(hash)
 				continue
 			}

+ 29 - 0
eth/downloader/downloader_test.go

@@ -336,3 +336,32 @@ func TestNonExistingParentAttack(t *testing.T) {
 		t.Fatalf("tester doesn't know about the origin hash")
 	}
 }
+
+// Tests that if a malicious peers keeps sending us repeating hashes, we don't
+// loop indefinitely.
+func TestRepeatingHashAttack(t *testing.T) {
+	// Create a valid chain, but drop the last link
+	hashes := createHashes(1000, 1)
+	blocks := createBlocksFromHashes(hashes)
+
+	hashes = hashes[:len(hashes)-1]
+
+	// Try and sync with the malicious node
+	tester := newTester(t, hashes, blocks)
+	tester.newPeer("attack", big.NewInt(10000), hashes[0])
+
+	errc := make(chan error)
+	go func() {
+		errc <- tester.sync("attack", hashes[0])
+	}()
+
+	// Make sure that syncing returns and does so with a failure
+	select {
+	case <-time.After(100 * time.Millisecond):
+		t.Fatalf("synchronisation blocked")
+	case err := <-errc:
+		if err == nil {
+			t.Fatalf("synchronisation succeeded")
+		}
+	}
+}

+ 11 - 9
eth/downloader/queue.go

@@ -122,24 +122,26 @@ func (q *queue) Has(hash common.Hash) bool {
 	return false
 }
 
-// Insert adds a set of hashes for the download queue for scheduling.
-func (q *queue) Insert(hashes []common.Hash) {
+// Insert adds a set of hashes for the download queue for scheduling, returning
+// the number of new hashes encountered.
+func (q *queue) Insert(hashes []common.Hash) int {
 	q.lock.Lock()
 	defer q.lock.Unlock()
 
 	// Insert all the hashes prioritized in the arrival order
-	for i, hash := range hashes {
-		index := q.hashCounter + i
-
+	inserts := 0
+	for _, hash := range hashes {
+		// Skip anything we already have
 		if old, ok := q.hashPool[hash]; ok {
 			glog.V(logger.Warn).Infof("Hash %x already scheduled at index %v", hash, old)
 			continue
 		}
-		q.hashPool[hash] = index
-		q.hashQueue.Push(hash, float32(index)) // Highest gets schedules first
+		// Update the counters and insert the hash
+		q.hashCounter, inserts = q.hashCounter+1, inserts+1
+		q.hashPool[hash] = q.hashCounter
+		q.hashQueue.Push(hash, float32(q.hashCounter)) // Highest gets schedules first
 	}
-	// Update the hash counter for the next batch of inserts
-	q.hashCounter += len(hashes)
+	return inserts
 }
 
 // GetHeadBlock retrieves the first block from the cache, or nil if it hasn't

+ 4 - 2
eth/sync.go

@@ -101,11 +101,13 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
 	case downloader.ErrBusy:
 		glog.V(logger.Debug).Infof("Synchronisation already in progress")
 
-	case downloader.ErrTimeout:
-		glog.V(logger.Debug).Infof("Removing peer %v due to sync timeout", peer.id)
+	case downloader.ErrTimeout, downloader.ErrBadPeer:
+		glog.V(logger.Debug).Infof("Removing peer %v: %v", peer.id, err)
 		pm.removePeer(peer)
+
 	case downloader.ErrPendingQueue:
 		glog.V(logger.Debug).Infoln("Synchronisation aborted:", err)
+
 	default:
 		glog.V(logger.Warn).Infof("Synchronisation failed: %v", err)
 	}