浏览代码

Encryption async api (#17603)

* swarm/storage/encryption: async segmentwise encryption/decryption

* swarm/storage: adapt hasherstore to encryption API change

* swarm/api: adapt RefEncryption for AC to new Encryption API

* swarm/storage/encryption: address review comments
Viktor Trón 7 年之前
父节点
当前提交
6dd87483d4
共有 4 个文件被更改,包括 146 次插入103 次删除
  1. 12 10
      swarm/api/encrypt.go
  2. 81 45
      swarm/storage/encryption/encryption.go
  3. 1 0
      swarm/storage/encryption/encryption_test.go
  4. 52 48
      swarm/storage/hasherstore.go

+ 12 - 10
swarm/api/encrypt.go

@@ -25,27 +25,27 @@ import (
 )
 
 type RefEncryption struct {
-	spanEncryption encryption.Encryption
-	dataEncryption encryption.Encryption
-	span           []byte
+	refSize int
+	span    []byte
 }
 
 func NewRefEncryption(refSize int) *RefEncryption {
 	span := make([]byte, 8)
 	binary.LittleEndian.PutUint64(span, uint64(refSize))
 	return &RefEncryption{
-		spanEncryption: encryption.New(0, uint32(refSize/32), sha3.NewKeccak256),
-		dataEncryption: encryption.New(refSize, 0, sha3.NewKeccak256),
-		span:           span,
+		refSize: refSize,
+		span:    span,
 	}
 }
 
 func (re *RefEncryption) Encrypt(ref []byte, key []byte) ([]byte, error) {
-	encryptedSpan, err := re.spanEncryption.Encrypt(re.span, key)
+	spanEncryption := encryption.New(key, 0, uint32(re.refSize/32), sha3.NewKeccak256)
+	encryptedSpan, err := spanEncryption.Encrypt(re.span)
 	if err != nil {
 		return nil, err
 	}
-	encryptedData, err := re.dataEncryption.Encrypt(ref, key)
+	dataEncryption := encryption.New(key, re.refSize, 0, sha3.NewKeccak256)
+	encryptedData, err := dataEncryption.Encrypt(ref)
 	if err != nil {
 		return nil, err
 	}
@@ -57,7 +57,8 @@ func (re *RefEncryption) Encrypt(ref []byte, key []byte) ([]byte, error) {
 }
 
 func (re *RefEncryption) Decrypt(ref []byte, key []byte) ([]byte, error) {
-	decryptedSpan, err := re.spanEncryption.Decrypt(ref[:8], key)
+	spanEncryption := encryption.New(key, 0, uint32(re.refSize/32), sha3.NewKeccak256)
+	decryptedSpan, err := spanEncryption.Decrypt(ref[:8])
 	if err != nil {
 		return nil, err
 	}
@@ -67,7 +68,8 @@ func (re *RefEncryption) Decrypt(ref []byte, key []byte) ([]byte, error) {
 		return nil, errors.New("invalid span in encrypted reference")
 	}
 
-	decryptedRef, err := re.dataEncryption.Decrypt(ref[8:], key)
+	dataEncryption := encryption.New(key, re.refSize, 0, sha3.NewKeccak256)
+	decryptedRef, err := dataEncryption.Decrypt(ref[8:])
 	if err != nil {
 		return nil, err
 	}

+ 81 - 45
swarm/storage/encryption/encryption.go

@@ -21,6 +21,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"hash"
+	"sync"
 )
 
 const KeyLength = 32
@@ -28,84 +29,119 @@ const KeyLength = 32
 type Key []byte
 
 type Encryption interface {
-	Encrypt(data []byte, key Key) ([]byte, error)
-	Decrypt(data []byte, key Key) ([]byte, error)
+	Encrypt(data []byte) ([]byte, error)
+	Decrypt(data []byte) ([]byte, error)
 }
 
 type encryption struct {
-	padding  int
-	initCtr  uint32
-	hashFunc func() hash.Hash
+	key      Key              // the encryption key (hashSize bytes long)
+	keyLen   int              // length of the key = length of blockcipher block
+	padding  int              // encryption will pad the data upto this if > 0
+	initCtr  uint32           // initial counter used for counter mode blockcipher
+	hashFunc func() hash.Hash // hasher constructor function
 }
 
-func New(padding int, initCtr uint32, hashFunc func() hash.Hash) *encryption {
+// New constructs a new encryptor/decryptor
+func New(key Key, padding int, initCtr uint32, hashFunc func() hash.Hash) *encryption {
 	return &encryption{
+		key:      key,
+		keyLen:   len(key),
 		padding:  padding,
 		initCtr:  initCtr,
 		hashFunc: hashFunc,
 	}
 }
 
-func (e *encryption) Encrypt(data []byte, key Key) ([]byte, error) {
+// Encrypt encrypts the data and does padding if specified
+func (e *encryption) Encrypt(data []byte) ([]byte, error) {
 	length := len(data)
+	outLength := length
 	isFixedPadding := e.padding > 0
-	if isFixedPadding && length > e.padding {
-		return nil, fmt.Errorf("Data length longer than padding, data length %v padding %v", length, e.padding)
-	}
-
-	paddedData := data
-	if isFixedPadding && length < e.padding {
-		paddedData = make([]byte, e.padding)
-		copy(paddedData[:length], data)
-		rand.Read(paddedData[length:])
+	if isFixedPadding {
+		if length > e.padding {
+			return nil, fmt.Errorf("Data length longer than padding, data length %v padding %v", length, e.padding)
+		}
+		outLength = e.padding
 	}
-	return e.transform(paddedData, key), nil
+	out := make([]byte, outLength)
+	e.transform(data, out)
+	return out, nil
 }
 
-func (e *encryption) Decrypt(data []byte, key Key) ([]byte, error) {
+// Decrypt decrypts the data, if padding was used caller must know original length and truncate
+func (e *encryption) Decrypt(data []byte) ([]byte, error) {
 	length := len(data)
 	if e.padding > 0 && length != e.padding {
 		return nil, fmt.Errorf("Data length different than padding, data length %v padding %v", length, e.padding)
 	}
+	out := make([]byte, length)
+	e.transform(data, out)
+	return out, nil
+}
 
-	return e.transform(data, key), nil
+//
+func (e *encryption) transform(in, out []byte) {
+	inLength := len(in)
+	wg := sync.WaitGroup{}
+	wg.Add((inLength-1)/e.keyLen + 1)
+	for i := 0; i < inLength; i += e.keyLen {
+		l := min(e.keyLen, inLength-i)
+		// call transformations per segment (asyncronously)
+		go func(i int, x, y []byte) {
+			defer wg.Done()
+			e.Transcrypt(i, x, y)
+		}(i/e.keyLen, in[i:i+l], out[i:i+l])
+	}
+	// pad the rest if out is longer
+	pad(out[inLength:])
+	wg.Wait()
 }
 
-func (e *encryption) transform(data []byte, key Key) []byte {
-	dataLength := len(data)
-	transformedData := make([]byte, dataLength)
+// used for segmentwise transformation
+// if in is shorter than out, padding is used
+func (e *encryption) Transcrypt(i int, in []byte, out []byte) {
+	// first hash key with counter (initial counter + i)
 	hasher := e.hashFunc()
-	ctr := e.initCtr
-	hashSize := hasher.Size()
-	for i := 0; i < dataLength; i += hashSize {
-		hasher.Write(key)
+	hasher.Write(e.key)
 
-		ctrBytes := make([]byte, 4)
-		binary.LittleEndian.PutUint32(ctrBytes, ctr)
+	ctrBytes := make([]byte, 4)
+	binary.LittleEndian.PutUint32(ctrBytes, uint32(i)+e.initCtr)
+	hasher.Write(ctrBytes)
 
-		hasher.Write(ctrBytes)
+	ctrHash := hasher.Sum(nil)
+	hasher.Reset()
 
-		ctrHash := hasher.Sum(nil)
-		hasher.Reset()
-		hasher.Write(ctrHash)
+	// second round of hashing for selective disclosure
+	hasher.Write(ctrHash)
+	segmentKey := hasher.Sum(nil)
+	hasher.Reset()
 
-		segmentKey := hasher.Sum(nil)
-
-		hasher.Reset()
+	// XOR bytes uptil length of in (out must be at least as long)
+	inLength := len(in)
+	for j := 0; j < inLength; j++ {
+		out[j] = in[j] ^ segmentKey[j]
+	}
+	// insert padding if out is longer
+	pad(out[inLength:])
+}
 
-		segmentSize := min(hashSize, dataLength-i)
-		for j := 0; j < segmentSize; j++ {
-			transformedData[i+j] = data[i+j] ^ segmentKey[j]
-		}
-		ctr++
+func pad(b []byte) {
+	l := len(b)
+	for total := 0; total < l; {
+		read, _ := rand.Read(b[total:])
+		total += read
 	}
-	return transformedData
 }
 
-func GenerateRandomKey() (Key, error) {
-	key := make([]byte, KeyLength)
-	_, err := rand.Read(key)
-	return key, err
+// GenerateRandomKey generates a random key of length l
+func GenerateRandomKey(l int) Key {
+	key := make([]byte, l)
+	var total int
+	for total < l {
+		read, _ := rand.Read(key[total:])
+		total += read
+	}
+	return key
 }
 
 func min(x, y int) int {

文件差异内容过多而无法显示
+ 1 - 0
swarm/storage/encryption/encryption_test.go


+ 52 - 48
swarm/storage/hasherstore.go

@@ -26,49 +26,34 @@ import (
 	"github.com/ethereum/go-ethereum/swarm/storage/encryption"
 )
 
-type chunkEncryption struct {
-	spanEncryption encryption.Encryption
-	dataEncryption encryption.Encryption
-}
-
 type hasherStore struct {
-	store           ChunkStore
-	hashFunc        SwarmHasher
-	chunkEncryption *chunkEncryption
-	hashSize        int   // content hash size
-	refSize         int64 // reference size (content hash + possibly encryption key)
-	wg              *sync.WaitGroup
-	closed          chan struct{}
-}
-
-func newChunkEncryption(chunkSize, refSize int64) *chunkEncryption {
-	return &chunkEncryption{
-		spanEncryption: encryption.New(0, uint32(chunkSize/refSize), sha3.NewKeccak256),
-		dataEncryption: encryption.New(int(chunkSize), 0, sha3.NewKeccak256),
-	}
+	store     ChunkStore
+	toEncrypt bool
+	hashFunc  SwarmHasher
+	hashSize  int   // content hash size
+	refSize   int64 // reference size (content hash + possibly encryption key)
+	wg        *sync.WaitGroup
+	closed    chan struct{}
 }
 
 // NewHasherStore creates a hasherStore object, which implements Putter and Getter interfaces.
 // With the HasherStore you can put and get chunk data (which is just []byte) into a ChunkStore
 // and the hasherStore will take core of encryption/decryption of data if necessary
 func NewHasherStore(chunkStore ChunkStore, hashFunc SwarmHasher, toEncrypt bool) *hasherStore {
-	var chunkEncryption *chunkEncryption
-
 	hashSize := hashFunc().Size()
 	refSize := int64(hashSize)
 	if toEncrypt {
 		refSize += encryption.KeyLength
-		chunkEncryption = newChunkEncryption(chunk.DefaultSize, refSize)
 	}
 
 	return &hasherStore{
-		store:           chunkStore,
-		hashFunc:        hashFunc,
-		chunkEncryption: chunkEncryption,
-		hashSize:        hashSize,
-		refSize:         refSize,
-		wg:              &sync.WaitGroup{},
-		closed:          make(chan struct{}),
+		store:     chunkStore,
+		toEncrypt: toEncrypt,
+		hashFunc:  hashFunc,
+		hashSize:  hashSize,
+		refSize:   refSize,
+		wg:        &sync.WaitGroup{},
+		closed:    make(chan struct{}),
 	}
 }
 
@@ -79,7 +64,7 @@ func (h *hasherStore) Put(ctx context.Context, chunkData ChunkData) (Reference,
 	c := chunkData
 	size := chunkData.Size()
 	var encryptionKey encryption.Key
-	if h.chunkEncryption != nil {
+	if h.toEncrypt {
 		var err error
 		c, encryptionKey, err = h.encryptChunkData(chunkData)
 		if err != nil {
@@ -155,23 +140,14 @@ func (h *hasherStore) encryptChunkData(chunkData ChunkData) (ChunkData, encrypti
 		return nil, nil, fmt.Errorf("Invalid ChunkData, min length 8 got %v", len(chunkData))
 	}
 
-	encryptionKey, err := encryption.GenerateRandomKey()
-	if err != nil {
-		return nil, nil, err
-	}
-
-	encryptedSpan, err := h.chunkEncryption.spanEncryption.Encrypt(chunkData[:8], encryptionKey)
-	if err != nil {
-		return nil, nil, err
-	}
-	encryptedData, err := h.chunkEncryption.dataEncryption.Encrypt(chunkData[8:], encryptionKey)
+	key, encryptedSpan, encryptedData, err := h.encrypt(chunkData)
 	if err != nil {
 		return nil, nil, err
 	}
 	c := make(ChunkData, len(encryptedSpan)+len(encryptedData))
 	copy(c[:8], encryptedSpan)
 	copy(c[8:], encryptedData)
-	return c, encryptionKey, nil
+	return c, key, nil
 }
 
 func (h *hasherStore) decryptChunkData(chunkData ChunkData, encryptionKey encryption.Key) (ChunkData, error) {
@@ -179,12 +155,7 @@ func (h *hasherStore) decryptChunkData(chunkData ChunkData, encryptionKey encryp
 		return nil, fmt.Errorf("Invalid ChunkData, min length 8 got %v", len(chunkData))
 	}
 
-	decryptedSpan, err := h.chunkEncryption.spanEncryption.Decrypt(chunkData[:8], encryptionKey)
-	if err != nil {
-		return nil, err
-	}
-
-	decryptedData, err := h.chunkEncryption.dataEncryption.Decrypt(chunkData[8:], encryptionKey)
+	decryptedSpan, decryptedData, err := h.decrypt(chunkData, encryptionKey)
 	if err != nil {
 		return nil, err
 	}
@@ -201,13 +172,46 @@ func (h *hasherStore) decryptChunkData(chunkData ChunkData, encryptionKey encryp
 	copy(c[:8], decryptedSpan)
 	copy(c[8:], decryptedData[:length])
 
-	return c[:length+8], nil
+	return c, nil
 }
 
 func (h *hasherStore) RefSize() int64 {
 	return h.refSize
 }
 
+func (h *hasherStore) encrypt(chunkData ChunkData) (encryption.Key, []byte, []byte, error) {
+	key := encryption.GenerateRandomKey(encryption.KeyLength)
+	encryptedSpan, err := h.newSpanEncryption(key).Encrypt(chunkData[:8])
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	encryptedData, err := h.newDataEncryption(key).Encrypt(chunkData[8:])
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	return key, encryptedSpan, encryptedData, nil
+}
+
+func (h *hasherStore) decrypt(chunkData ChunkData, key encryption.Key) ([]byte, []byte, error) {
+	encryptedSpan, err := h.newSpanEncryption(key).Encrypt(chunkData[:8])
+	if err != nil {
+		return nil, nil, err
+	}
+	encryptedData, err := h.newDataEncryption(key).Encrypt(chunkData[8:])
+	if err != nil {
+		return nil, nil, err
+	}
+	return encryptedSpan, encryptedData, nil
+}
+
+func (h *hasherStore) newSpanEncryption(key encryption.Key) encryption.Encryption {
+	return encryption.New(key, 0, uint32(chunk.DefaultSize/h.refSize), sha3.NewKeccak256)
+}
+
+func (h *hasherStore) newDataEncryption(key encryption.Key) encryption.Encryption {
+	return encryption.New(key, int(chunk.DefaultSize), 0, sha3.NewKeccak256)
+}
+
 func (h *hasherStore) storeChunk(ctx context.Context, chunk *Chunk) {
 	h.wg.Add(1)
 	go func() {

部分文件因为文件数量过多而无法显示