Improve design of Base64Cipher wrapping other ciphers.
Have it take in a cipher init function as an argument. Remove the confusing `newCipher` method that matched legacy behavior and returns a Base64Cipher(CFBCipher) -- instead explicitly ask for that in the uses.
This commit is contained in:
		
							parent
							
								
									b6931aa4ea
								
							
						
					
					
						commit
						ce2e92bc57
					
				|  | @ -17,10 +17,14 @@ func timePtr(t time.Time) *time.Time { | ||||||
| 	return &t | 	return &t | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func NewCipher(secret []byte) (encryption.Cipher, error) { | ||||||
|  | 	return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestSessionStateSerialization(t *testing.T) { | func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	c, err := encryption.NewCipher([]byte(secret)) | 	c, err := NewCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := encryption.NewCipher([]byte(altSecret)) | 	c2, err := NewCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		Email:             "user@domain.com", | 		Email:             "user@domain.com", | ||||||
|  | @ -53,9 +57,9 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerializationWithUser(t *testing.T) { | func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	c, err := encryption.NewCipher([]byte(secret)) | 	c, err := NewCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := encryption.NewCipher([]byte(altSecret)) | 	c2, err := NewCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		User:              "just-user", | 		User:              "just-user", | ||||||
|  | @ -201,7 +205,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 	eJSON, _ := e.MarshalJSON() | 	eJSON, _ := e.MarshalJSON() | ||||||
| 	eString := string(eJSON) | 	eString := string(eJSON) | ||||||
| 
 | 
 | ||||||
| 	c, err := encryption.NewCipher([]byte(secret)) | 	c, err := NewCipher([]byte(secret)) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
|  |  | ||||||
|  | @ -135,16 +135,6 @@ func (c *DefaultCipher) DecryptInto(s *string) error { | ||||||
| 	return into(c.Decrypt, s) | 	return into(c.Decrypt, s) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewCipher returns a new aes Cipher for encrypting cookie values
 |  | ||||||
| // This defaults to the Base64 Cipher to align with legacy Encrypt/Decrypt functionality
 |  | ||||||
| func NewCipher(secret []byte) (Cipher, error) { |  | ||||||
| 	cfb, err := NewCFBCipher(secret) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	return NewBase64Cipher(cfb) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type Base64Cipher struct { | type Base64Cipher struct { | ||||||
| 	DefaultCipher | 	DefaultCipher | ||||||
| 	Cipher Cipher | 	Cipher Cipher | ||||||
|  | @ -152,7 +142,11 @@ type Base64Cipher struct { | ||||||
| 
 | 
 | ||||||
| // NewBase64Cipher returns a new AES Cipher for encrypting cookie values
 | // NewBase64Cipher returns a new AES Cipher for encrypting cookie values
 | ||||||
| // and wrapping them in Base64 -- Supports Legacy encryption scheme
 | // and wrapping them in Base64 -- Supports Legacy encryption scheme
 | ||||||
| func NewBase64Cipher(c Cipher) (Cipher, error) { | func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Cipher, error) { | ||||||
|  | 	c, err := initCipher(secret) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	return &Base64Cipher{Cipher: c}, nil | 	return &Base64Cipher{Cipher: c}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -170,7 +164,7 @@ func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) | 	encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to decrypt cookie value %s", err) | 		return nil, fmt.Errorf("failed to base64 decode value %s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return c.Cipher.Decrypt(encrypted) | 	return c.Cipher.Decrypt(encrypted) | ||||||
|  | @ -206,9 +200,7 @@ func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| // Decrypt an AES CFB ciphertext
 | // Decrypt an AES CFB ciphertext
 | ||||||
| func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	if len(ciphertext) < aes.BlockSize { | 	if len(ciphertext) < aes.BlockSize { | ||||||
| 		return nil, fmt.Errorf("encrypted value should be "+ | 		return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext)) | ||||||
| 			"at least %d bytes, but is only %d bytes", |  | ||||||
| 			aes.BlockSize, len(ciphertext)) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	iv := ciphertext[:aes.BlockSize] | 	iv := ciphertext[:aes.BlockSize] | ||||||
|  |  | ||||||
|  | @ -102,7 +102,7 @@ func TestSignAndValidate(t *testing.T) { | ||||||
| func TestEncodeAndDecodeAccessToken(t *testing.T) { | func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||||
| 	const secret = "0123456789abcdefghijklmnopqrstuv" | 	const secret = "0123456789abcdefghijklmnopqrstuv" | ||||||
| 	const token = "my access token" | 	const token = "my access token" | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	encoded, err := c.Encrypt([]byte(token)) | 	encoded, err := c.Encrypt([]byte(token)) | ||||||
|  | @ -121,7 +121,7 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	secret, err := base64.URLEncoding.DecodeString(secretBase64) | 	secret, err := base64.URLEncoding.DecodeString(secretBase64) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	encoded, err := c.Encrypt([]byte(token)) | 	encoded, err := c.Encrypt([]byte(token)) | ||||||
|  | @ -135,14 +135,12 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestEncryptAndDecrypt(t *testing.T) { | func TestEncryptAndDecrypt(t *testing.T) { | ||||||
| 	var err error | 	// Test our 2 cipher types
 | ||||||
| 
 | 	for _, initCipher := range []func([]byte) (Cipher, error){NewCFBCipher, NewGCMCipher} { | ||||||
| 	// Test our 3 cipher types
 |  | ||||||
| 	for _, initCipher := range []func([]byte) (Cipher, error){NewCipher, NewCFBCipher, NewGCMCipher} { |  | ||||||
| 		// Test all 3 valid AES sizes
 | 		// Test all 3 valid AES sizes
 | ||||||
| 		for _, secretSize := range []int{16, 24, 32} { | 		for _, secretSize := range []int{16, 24, 32} { | ||||||
| 			secret := make([]byte, secretSize) | 			secret := make([]byte, secretSize) | ||||||
| 			_, err = io.ReadFull(rand.Reader, secret) | 			_, err := io.ReadFull(rand.Reader, secret) | ||||||
| 			assert.Equal(t, nil, err) | 			assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 			c, err := initCipher(secret) | 			c, err := initCipher(secret) | ||||||
|  | @ -167,27 +165,55 @@ func TestEncryptAndDecrypt(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestDecryptWrongSecret(t *testing.T) { | func TestEncryptAndDecryptBase64(t *testing.T) { | ||||||
|  | 	// Test our cipher types wrapped in Base64 encoder
 | ||||||
|  | 	for _, initCipher := range []func([]byte) (Cipher, error){NewCFBCipher, NewGCMCipher} { | ||||||
|  | 		// Test all 3 valid AES sizes
 | ||||||
|  | 		for _, secretSize := range []int{16, 24, 32} { | ||||||
|  | 			secret := make([]byte, secretSize) | ||||||
|  | 			_, err := io.ReadFull(rand.Reader, secret) | ||||||
|  | 			assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 			c, err := NewBase64Cipher(initCipher, secret) | ||||||
|  | 			assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 			// Test various sizes sessions might be
 | ||||||
|  | 			for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||||
|  | 				data := make([]byte, dataSize) | ||||||
|  | 				_, err := io.ReadFull(rand.Reader, data) | ||||||
|  | 				assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 				encrypted, err := c.Encrypt(data) | ||||||
|  | 				assert.Equal(t, nil, err) | ||||||
|  | 				assert.NotEqual(t, encrypted, data) | ||||||
|  | 
 | ||||||
|  | 				decrypted, err := c.Decrypt(encrypted) | ||||||
|  | 				assert.Equal(t, nil, err) | ||||||
|  | 				assert.Equal(t, data, decrypted) | ||||||
|  | 				assert.NotEqual(t, encrypted, decrypted) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestDecryptCFBWrongSecret(t *testing.T) { | ||||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||||
| 
 | 
 | ||||||
| 	// Test CFB & Base64 (GCM is authenticated, it errors differently)
 | 	c1, err := NewCFBCipher(secret1) | ||||||
| 	for _, initCipher := range []func([]byte) (Cipher, error){NewCipher, NewCFBCipher} { | 	assert.Equal(t, nil, err) | ||||||
| 		c1, err := initCipher(secret1) |  | ||||||
| 		assert.Equal(t, nil, err) |  | ||||||
| 
 | 
 | ||||||
| 		c2, err := initCipher(secret2) | 	c2, err := NewCFBCipher(secret2) | ||||||
| 		assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 		data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df") | 	data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df") | ||||||
| 
 | 
 | ||||||
| 		ciphertext, err := c1.Encrypt(data) | 	ciphertext, err := c1.Encrypt(data) | ||||||
| 		assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 		wrongData, err := c2.Decrypt(ciphertext) | 	wrongData, err := c2.Decrypt(ciphertext) | ||||||
| 		assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 		assert.NotEqual(t, data, wrongData) | 	assert.NotEqual(t, data, wrongData) | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestDecryptGCMWrongSecret(t *testing.T) { | func TestDecryptGCMWrongSecret(t *testing.T) { | ||||||
|  | @ -211,13 +237,11 @@ func TestDecryptGCMWrongSecret(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestIntermixCiphersErrors(t *testing.T) { | func TestIntermixCiphersErrors(t *testing.T) { | ||||||
| 	var err error |  | ||||||
| 
 |  | ||||||
| 	// Encrypt with GCM, Decrypt with CFB: Results in Garbage data
 | 	// Encrypt with GCM, Decrypt with CFB: Results in Garbage data
 | ||||||
| 	// Test all 3 valid AES sizes
 | 	// Test all 3 valid AES sizes
 | ||||||
| 	for _, secretSize := range []int{16, 24, 32} { | 	for _, secretSize := range []int{16, 24, 32} { | ||||||
| 		secret := make([]byte, secretSize) | 		secret := make([]byte, secretSize) | ||||||
| 		_, err = io.ReadFull(rand.Reader, secret) | 		_, err := io.ReadFull(rand.Reader, secret) | ||||||
| 		assert.Equal(t, nil, err) | 		assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 		gcm, err := NewGCMCipher(secret) | 		gcm, err := NewGCMCipher(secret) | ||||||
|  | @ -248,7 +272,7 @@ func TestIntermixCiphersErrors(t *testing.T) { | ||||||
| 	// Test all 3 valid AES sizes
 | 	// Test all 3 valid AES sizes
 | ||||||
| 	for _, secretSize := range []int{16, 24, 32} { | 	for _, secretSize := range []int{16, 24, 32} { | ||||||
| 		secret := make([]byte, secretSize) | 		secret := make([]byte, secretSize) | ||||||
| 		_, err = io.ReadFull(rand.Reader, secret) | 		_, err := io.ReadFull(rand.Reader, secret) | ||||||
| 		assert.Equal(t, nil, err) | 		assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 		gcm, err := NewGCMCipher(secret) | 		gcm, err := NewGCMCipher(secret) | ||||||
|  |  | ||||||
|  | @ -367,7 +367,7 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 				_, err := rand.Read(secret) | 				_, err := rand.Read(secret) | ||||||
| 				Expect(err).ToNot(HaveOccurred()) | 				Expect(err).ToNot(HaveOccurred()) | ||||||
| 				cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret) | 				cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret) | ||||||
| 				cipher, err := encryption.NewCipher(encryption.SecretBytes(cookieOpts.Secret)) | 				cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) | ||||||
| 				Expect(err).ToNot(HaveOccurred()) | 				Expect(err).ToNot(HaveOccurred()) | ||||||
| 				Expect(cipher).ToNot(BeNil()) | 				Expect(cipher).ToNot(BeNil()) | ||||||
| 				opts.Cipher = cipher | 				opts.Cipher = cipher | ||||||
|  |  | ||||||
|  | @ -62,7 +62,7 @@ func Validate(o *options.Options) error { | ||||||
| 					len(encryption.SecretBytes(o.Cookie.Secret)), suffix)) | 					len(encryption.SecretBytes(o.Cookie.Secret)), suffix)) | ||||||
| 		} else { | 		} else { | ||||||
| 			var err error | 			var err error | ||||||
| 			cipher, err = encryption.NewCipher(encryption.SecretBytes(o.Cookie.Secret)) | 			cipher, err = encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(o.Cookie.Secret)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err)) | 				msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err)) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue