diff --git a/pkg/encryption/cipher.go b/pkg/encryption/cipher.go index 1ae521ea..d0d31882 100644 --- a/pkg/encryption/cipher.go +++ b/pkg/encryption/cipher.go @@ -50,7 +50,7 @@ func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Ci return &Base64Cipher{Cipher: c}, nil } -// Encrypt encrypts a value with AES CFB & Base64 encodes it +// Encrypt encrypts a value with the embedded Cipher & Base64 encodes it func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { encrypted, err := c.Cipher.Encrypt([]byte(value)) if err != nil { @@ -60,7 +60,7 @@ func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { return []byte(base64.StdEncoding.EncodeToString(encrypted)), nil } -// Decrypt Base64 decodes a value & decrypts it with AES CFB +// Decrypt Base64 decodes a value & decrypts it with the embedded Cipher func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) if err != nil { @@ -103,12 +103,12 @@ func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext)) } - iv := ciphertext[:aes.BlockSize] - ciphertext = ciphertext[aes.BlockSize:] + iv, ciphertext := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:] + plaintext := make([]byte, len(ciphertext)) stream := cipher.NewCFBDecrypter(c.Block, iv) - stream.XORKeyStream(ciphertext, ciphertext) + stream.XORKeyStream(plaintext, ciphertext) - return ciphertext, nil + return plaintext, nil } type GCMCipher struct { @@ -135,6 +135,8 @@ func (c *GCMCipher) Encrypt(value []byte) ([]byte, error) { if _, err = io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } + // Using nonce as Seal's dst argument results in it being the first + // chunk of bytes in the ciphertext. Decrypt retrieves the nonce/IV from this. ciphertext := gcm.Seal(nonce, nonce, value, nil) return ciphertext, nil } diff --git a/pkg/encryption/cipher_test.go b/pkg/encryption/cipher_test.go index f146017f..c49fb67a 100644 --- a/pkg/encryption/cipher_test.go +++ b/pkg/encryption/cipher_test.go @@ -84,9 +84,17 @@ func TestEncryptAndDecrypt(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, encrypted, data) + // Ensure our Decrypt function doesn't decrypt in place + immutable := make([]byte, len(encrypted)) + copy(immutable, encrypted) + decrypted, err := c.Decrypt(encrypted) assert.Equal(t, nil, err) + // Original data back assert.Equal(t, data, decrypted) + // Decrypt didn't operate in-place on []byte + assert.Equal(t, encrypted, immutable) + // Encrypt/Decrypt actually did something assert.NotEqual(t, encrypted, decrypted) }) }