Sfoglia il codice sorgente

simplify account unlocking

zelig 10 anni fa
parent
commit
1d72aaa0cd
2 ha cambiato i file con 85 aggiunte e 56 eliminazioni
  1. 43 55
      accounts/account_manager.go
  2. 42 1
      accounts/accounts_test.go

+ 43 - 55
accounts/account_manager.go

@@ -49,11 +49,6 @@ var (
 	ErrNoKeys = errors.New("no keys in store")
 )
 
-const (
-    // Default unlock duration (in seconds) when an account is unlocked from the console
-    DefaultAccountUnlockDuration = 300
-)
-
 type Account struct {
 	Address common.Address
 }
@@ -114,28 +109,58 @@ func (am *Manager) Sign(a Account, toSign []byte) (signature []byte, err error)
 	return signature, err
 }
 
-// TimedUnlock unlocks the account with the given address.
-// When timeout has passed, the account will be locked again.
+// unlock indefinitely
+func (am *Manager) Unlock(addr common.Address, keyAuth string) error {
+	return am.TimedUnlock(addr, keyAuth, 0)
+}
+
+// Unlock unlocks the account with the given address. The account
+// stays unlocked for the duration of timeout
+// it timeout is 0 the account is unlocked for the entire session
 func (am *Manager) TimedUnlock(addr common.Address, keyAuth string, timeout time.Duration) error {
 	key, err := am.keyStore.GetKey(addr, keyAuth)
 	if err != nil {
 		return err
 	}
-	u := am.addUnlocked(addr, key)
-	go am.dropLater(addr, u, timeout)
+	var u *unlocked
+	am.mutex.Lock()
+	defer am.mutex.Unlock()
+	var found bool
+	u, found = am.unlocked[addr]
+	if found {
+		// terminate dropLater for this key to avoid unexpected drops.
+		if u.abort != nil {
+			close(u.abort)
+		}
+	}
+	if timeout > 0 {
+		u = &unlocked{Key: key, abort: make(chan struct{})}
+		go am.expire(addr, u, timeout)
+	} else {
+		u = &unlocked{Key: key}
+	}
+	am.unlocked[addr] = u
 	return nil
 }
 
-// Unlock unlocks the account with the given address. The account
-// stays unlocked until the program exits or until a TimedUnlock
-// timeout (started after the call to Unlock) expires.
-func (am *Manager) Unlock(addr common.Address, keyAuth string) error {
-	key, err := am.keyStore.GetKey(addr, keyAuth)
-	if err != nil {
-		return err
+func (am *Manager) expire(addr common.Address, u *unlocked, timeout time.Duration) {
+	t := time.NewTimer(timeout)
+	defer t.Stop()
+	select {
+	case <-u.abort:
+		// just quit
+	case <-t.C:
+		am.mutex.Lock()
+		// only drop if it's still the same key instance that dropLater
+		// was launched with. we can check that using pointer equality
+		// because the map stores a new pointer every time the key is
+		// unlocked.
+		if am.unlocked[addr] == u {
+			zeroKey(u.PrivateKey)
+			delete(am.unlocked, addr)
+		}
+		am.mutex.Unlock()
 	}
-	am.addUnlocked(addr, key)
-	return nil
 }
 
 func (am *Manager) NewAccount(auth string) (Account, error) {
@@ -162,43 +187,6 @@ func (am *Manager) Accounts() ([]Account, error) {
 	return accounts, err
 }
 
-func (am *Manager) addUnlocked(addr common.Address, key *crypto.Key) *unlocked {
-	u := &unlocked{Key: key, abort: make(chan struct{})}
-	am.mutex.Lock()
-	prev, found := am.unlocked[addr]
-	if found {
-		// terminate dropLater for this key to avoid unexpected drops.
-		close(prev.abort)
-		// the key is zeroed here instead of in dropLater because
-		// there might not actually be a dropLater running for this
-		// key, i.e. when Unlock was used.
-		zeroKey(prev.PrivateKey)
-	}
-	am.unlocked[addr] = u
-	am.mutex.Unlock()
-	return u
-}
-
-func (am *Manager) dropLater(addr common.Address, u *unlocked, timeout time.Duration) {
-	t := time.NewTimer(timeout)
-	defer t.Stop()
-	select {
-	case <-u.abort:
-		// just quit
-	case <-t.C:
-		am.mutex.Lock()
-		// only drop if it's still the same key instance that dropLater
-		// was launched with. we can check that using pointer equality
-		// because the map stores a new pointer every time the key is
-		// unlocked.
-		if am.unlocked[addr] == u {
-			zeroKey(u.PrivateKey)
-			delete(am.unlocked, addr)
-		}
-		am.mutex.Unlock()
-	}
-}
-
 // zeroKey zeroes a private key in memory.
 func zeroKey(k *ecdsa.PrivateKey) {
 	b := k.D.Bits()

+ 42 - 1
accounts/accounts_test.go

@@ -18,7 +18,7 @@ func TestSign(t *testing.T) {
 	pass := "" // not used but required by API
 	a1, err := am.NewAccount(pass)
 	toSign := randentropy.GetEntropyCSPRNG(32)
-	am.Unlock(a1.Address, "")
+	am.Unlock(a1.Address, "", 0)
 
 	_, err = am.Sign(a1, toSign)
 	if err != nil {
@@ -58,6 +58,47 @@ func TestTimedUnlock(t *testing.T) {
 	if err != ErrLocked {
 		t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err)
 	}
+
+}
+
+func TestOverrideUnlock(t *testing.T) {
+	dir, ks := tmpKeyStore(t, crypto.NewKeyStorePassphrase)
+	defer os.RemoveAll(dir)
+
+	am := NewManager(ks)
+	pass := "foo"
+	a1, err := am.NewAccount(pass)
+	toSign := randentropy.GetEntropyCSPRNG(32)
+
+	// Unlock indefinitely
+	if err = am.Unlock(a1.Address, pass); err != nil {
+		t.Fatal(err)
+	}
+
+	// Signing without passphrase works because account is temp unlocked
+	_, err = am.Sign(a1, toSign)
+	if err != nil {
+		t.Fatal("Signing shouldn't return an error after unlocking, got ", err)
+	}
+
+	// reset unlock to a shorter period, invalidates the previous unlock
+	if err = am.TimedUnlock(a1.Address, pass, 100*time.Millisecond); err != nil {
+		t.Fatal(err)
+	}
+
+	// Signing without passphrase still works because account is temp unlocked
+	_, err = am.Sign(a1, toSign)
+	if err != nil {
+		t.Fatal("Signing shouldn't return an error after unlocking, got ", err)
+	}
+
+	// Signing fails again after automatic locking
+	time.Sleep(150 * time.Millisecond)
+	_, err = am.Sign(a1, toSign)
+	if err != ErrLocked {
+		t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err)
+	}
+
 }
 
 func tmpKeyStore(t *testing.T, new func(string) crypto.KeyStore2) (string, crypto.KeyStore2) {