Procházet zdrojové kódy

accounts/keystore: fix double import race (#20915)

* accounts/keystore: fix race in Import/ImportECDSA

* accounts/keystore: added import/export tests

* cmd/geth: improved TestAccountImport test

* accounts/keystore: added import/export tests

* accounts/keystore: fixed naming

* accounts/keystore: fixed typo

* accounts/keystore: use mutex instead of rwmutex

* accounts: use errors instead of fmt
Marius van der Wijden před 5 roky
rodič
revize
38aab0aa83

+ 10 - 3
accounts/keystore/keystore.go

@@ -24,7 +24,6 @@ import (
 	"crypto/ecdsa"
 	crand "crypto/rand"
 	"errors"
-	"fmt"
 	"math/big"
 	"os"
 	"path/filepath"
@@ -67,7 +66,8 @@ type KeyStore struct {
 	updateScope event.SubscriptionScope // Subscription scope tracking current live listeners
 	updating    bool                    // Whether the event notification loop is running
 
-	mu sync.RWMutex
+	mu       sync.RWMutex
+	importMu sync.Mutex // Import Mutex locks the import to prevent two insertions from racing
 }
 
 type unlocked struct {
@@ -443,14 +443,21 @@ func (ks *KeyStore) Import(keyJSON []byte, passphrase, newPassphrase string) (ac
 	if err != nil {
 		return accounts.Account{}, err
 	}
+	ks.importMu.Lock()
+	defer ks.importMu.Unlock()
+	if ks.cache.hasAddress(key.Address) {
+		return accounts.Account{}, errors.New("account already exists")
+	}
 	return ks.importKey(key, newPassphrase)
 }
 
 // ImportECDSA stores the given key into the key directory, encrypting it with the passphrase.
 func (ks *KeyStore) ImportECDSA(priv *ecdsa.PrivateKey, passphrase string) (accounts.Account, error) {
 	key := newKeyFromECDSA(priv)
+	ks.importMu.Lock()
+	defer ks.importMu.Unlock()
 	if ks.cache.hasAddress(key.Address) {
-		return accounts.Account{}, fmt.Errorf("account already exists")
+		return accounts.Account{}, errors.New("account already exists")
 	}
 	return ks.importKey(key, passphrase)
 }

+ 85 - 0
accounts/keystore/keystore_test.go

@@ -23,11 +23,14 @@ import (
 	"runtime"
 	"sort"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
 	"github.com/ethereum/go-ethereum/accounts"
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/event"
 )
 
@@ -338,6 +341,88 @@ func TestWalletNotifications(t *testing.T) {
 	checkEvents(t, wantEvents, events)
 }
 
+// TestImportExport tests the import functionality of a keystore.
+func TestImportECDSA(t *testing.T) {
+	dir, ks := tmpKeyStore(t, true)
+	defer os.RemoveAll(dir)
+	key, err := crypto.GenerateKey()
+	if err != nil {
+		t.Fatalf("failed to generate key: %v", key)
+	}
+	if _, err = ks.ImportECDSA(key, "old"); err != nil {
+		t.Errorf("importing failed: %v", err)
+	}
+	if _, err = ks.ImportECDSA(key, "old"); err == nil {
+		t.Errorf("importing same key twice succeeded")
+	}
+	if _, err = ks.ImportECDSA(key, "new"); err == nil {
+		t.Errorf("importing same key twice succeeded")
+	}
+}
+
+// TestImportECDSA tests the import and export functionality of a keystore.
+func TestImportExport(t *testing.T) {
+	dir, ks := tmpKeyStore(t, true)
+	defer os.RemoveAll(dir)
+	acc, err := ks.NewAccount("old")
+	if err != nil {
+		t.Fatalf("failed to create account: %v", acc)
+	}
+	json, err := ks.Export(acc, "old", "new")
+	if err != nil {
+		t.Fatalf("failed to export account: %v", acc)
+	}
+	dir2, ks2 := tmpKeyStore(t, true)
+	defer os.RemoveAll(dir2)
+	if _, err = ks2.Import(json, "old", "old"); err == nil {
+		t.Errorf("importing with invalid password succeeded")
+	}
+	acc2, err := ks2.Import(json, "new", "new")
+	if err != nil {
+		t.Errorf("importing failed: %v", err)
+	}
+	if acc.Address != acc2.Address {
+		t.Error("imported account does not match exported account")
+	}
+	if _, err = ks2.Import(json, "new", "new"); err == nil {
+		t.Errorf("importing a key twice succeeded")
+	}
+
+}
+
+// TestImportRace tests the keystore on races.
+// This test should fail under -race if importing races.
+func TestImportRace(t *testing.T) {
+	dir, ks := tmpKeyStore(t, true)
+	defer os.RemoveAll(dir)
+	acc, err := ks.NewAccount("old")
+	if err != nil {
+		t.Fatalf("failed to create account: %v", acc)
+	}
+	json, err := ks.Export(acc, "old", "new")
+	if err != nil {
+		t.Fatalf("failed to export account: %v", acc)
+	}
+	dir2, ks2 := tmpKeyStore(t, true)
+	defer os.RemoveAll(dir2)
+	var atom uint32
+	var wg sync.WaitGroup
+	wg.Add(2)
+	for i := 0; i < 2; i++ {
+		go func() {
+			defer wg.Done()
+			if _, err := ks2.Import(json, "new", "new"); err != nil {
+				atomic.AddUint32(&atom, 1)
+			}
+
+		}()
+	}
+	wg.Wait()
+	if atom != 1 {
+		t.Errorf("Import is racy")
+	}
+}
+
 // checkAccounts checks that all known live accounts are present in the wallet list.
 func checkAccounts(t *testing.T, live map[common.Address]accounts.Account, wallets []accounts.Wallet) {
 	if len(live) != len(wallets) {

+ 7 - 2
cmd/geth/accountcmd_test.go

@@ -89,18 +89,23 @@ Path of the secret key file: .*UTC--.+--[0-9a-f]{40}
 }
 
 func TestAccountImport(t *testing.T) {
-	tests := []struct{ key, output string }{
+	tests := []struct{ name, key, output string }{
 		{
+			name:   "correct account",
 			key:    "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
 			output: "Address: {fcad0b19bb29d4674531d6f115237e16afce377c}\n",
 		},
 		{
+			name:   "invalid character",
 			key:    "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef1",
 			output: "Fatal: Failed to load the private key: invalid character '1' at end of key file\n",
 		},
 	}
 	for _, test := range tests {
-		importAccountWithExpect(t, test.key, test.output)
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+			importAccountWithExpect(t, test.key, test.output)
+		})
 	}
 }