Add EncryptInto/DecryptInto Unit Tests
This commit is contained in:
		
							parent
							
								
									e43c65cc76
								
							
						
					
					
						commit
						014fa682be
					
				|  | @ -55,6 +55,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v5.1.1 | ## Changes since v5.1.1 | ||||||
| 
 | 
 | ||||||
|  | - [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves) | ||||||
| - [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed) | - [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed) | ||||||
| - [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed) | - [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed) | ||||||
| - [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer) | - [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer) | ||||||
|  |  | ||||||
|  | @ -104,25 +104,19 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||||
| 			PreferredUsername: ss.PreferredUsername, | 			PreferredUsername: ss.PreferredUsername, | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		// Backward compatibility with using unencrypted Email
 | 		// Backward compatibility with using unencrypted Email or User
 | ||||||
| 		if ss.Email != "" { | 		// Decryption errors will leave original string
 | ||||||
| 			decryptedEmail, errEmail := stringDecrypt(ss.Email, c) | 		err = c.DecryptInto(&ss.Email) | ||||||
| 			if errEmail == nil { | 		if err == nil { | ||||||
| 				if !utf8.ValidString(decryptedEmail) { | 			if !utf8.ValidString(ss.Email) { | ||||||
| 				return nil, errors.New("invalid value for decrypted email") | 				return nil, errors.New("invalid value for decrypted email") | ||||||
| 			} | 			} | ||||||
| 				ss.Email = decryptedEmail |  | ||||||
| 		} | 		} | ||||||
| 		} | 		err = c.DecryptInto(&ss.User) | ||||||
| 		// Backward compatibility with using unencrypted User
 | 		if err == nil { | ||||||
| 		if ss.User != "" { | 			if !utf8.ValidString(ss.User) { | ||||||
| 			decryptedUser, errUser := stringDecrypt(ss.User, c) |  | ||||||
| 			if errUser == nil { |  | ||||||
| 				if !utf8.ValidString(decryptedUser) { |  | ||||||
| 				return nil, errors.New("invalid value for decrypted user") | 				return nil, errors.New("invalid value for decrypted user") | ||||||
| 			} | 			} | ||||||
| 				ss.User = decryptedUser |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for _, s := range []*string{ | 		for _, s := range []*string{ | ||||||
|  | @ -139,12 +133,3 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||||
| 	} | 	} | ||||||
| 	return &ss, nil | 	return &ss, nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // stringDecrypt wraps a Base64Cipher to make it string => string
 |  | ||||||
| func stringDecrypt(ciphertext string, c encryption.Cipher) (string, error) { |  | ||||||
| 	value, err := c.Decrypt([]byte(ciphertext)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	return string(value), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -17,14 +17,14 @@ func timePtr(t time.Time) *time.Time { | ||||||
| 	return &t | 	return &t | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewCipher(secret []byte) (encryption.Cipher, error) { | func newTestCipher(secret []byte) (encryption.Cipher, error) { | ||||||
| 	return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) | 	return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerialization(t *testing.T) { | func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := newTestCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := NewCipher([]byte(altSecret)) | 	c2, err := newTestCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		Email:             "user@domain.com", | 		Email:             "user@domain.com", | ||||||
|  | @ -57,9 +57,9 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerializationWithUser(t *testing.T) { | func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := newTestCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := NewCipher([]byte(altSecret)) | 	c2, err := newTestCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		User:              "just-user", | 		User:              "just-user", | ||||||
|  | @ -205,7 +205,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 	eJSON, _ := e.MarshalJSON() | 	eJSON, _ := e.MarshalJSON() | ||||||
| 	eString := string(eJSON) | 	eString := string(eJSON) | ||||||
| 
 | 
 | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := newTestCipher([]byte(secret)) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
|  |  | ||||||
|  | @ -17,26 +17,7 @@ type Cipher interface { | ||||||
| 	DecryptInto(s *string) error | 	DecryptInto(s *string) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type DefaultCipher struct {} |  | ||||||
| 
 |  | ||||||
| // Encrypt is a dummy method for CommonCipher.EncryptInto support
 |  | ||||||
| func (c *DefaultCipher) Encrypt(value []byte) ([]byte, error) { return value, nil } |  | ||||||
| 
 |  | ||||||
| // Decrypt is a dummy method for CommonCipher.DecryptInto support
 |  | ||||||
| func (c *DefaultCipher) Decrypt(ciphertext []byte) ([]byte, error) { return ciphertext, nil } |  | ||||||
| 
 |  | ||||||
| // EncryptInto encrypts the value and stores it back in the string pointer
 |  | ||||||
| func (c *DefaultCipher) EncryptInto(s *string) error { |  | ||||||
| 	return into(c.Encrypt, s) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // DecryptInto decrypts the value and stores it back in the string pointer
 |  | ||||||
| func (c *DefaultCipher) DecryptInto(s *string) error { |  | ||||||
| 	return into(c.Decrypt, s) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type Base64Cipher struct { | type Base64Cipher struct { | ||||||
| 	DefaultCipher |  | ||||||
| 	Cipher Cipher | 	Cipher Cipher | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -52,7 +33,7 @@ func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Ci | ||||||
| 
 | 
 | ||||||
| // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
 | // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
 | ||||||
| func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| 	encrypted, err := c.Cipher.Encrypt([]byte(value)) | 	encrypted, err := c.Cipher.Encrypt(value) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -70,8 +51,17 @@ func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	return c.Cipher.Decrypt(encrypted) | 	return c.Cipher.Decrypt(encrypted) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // EncryptInto encrypts the value and stores it back in the string pointer
 | ||||||
|  | func (c *Base64Cipher) EncryptInto(s *string) error { | ||||||
|  | 	return into(c.Encrypt, s) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DecryptInto decrypts the value and stores it back in the string pointer
 | ||||||
|  | func (c *Base64Cipher) DecryptInto(s *string) error { | ||||||
|  | 	return into(c.Decrypt, s) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type CFBCipher struct { | type CFBCipher struct { | ||||||
| 	DefaultCipher |  | ||||||
| 	cipher.Block | 	cipher.Block | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -111,8 +101,17 @@ func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	return plaintext, nil | 	return plaintext, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able
 | ||||||
|  | func (c *CFBCipher) EncryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // EncryptInto returns an error since the encrypted data needs to be a []byte
 | ||||||
|  | func (c *CFBCipher) DecryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type GCMCipher struct { | type GCMCipher struct { | ||||||
| 	DefaultCipher |  | ||||||
| 	cipher.Block | 	cipher.Block | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -158,6 +157,16 @@ func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	return plaintext, nil | 	return plaintext, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able
 | ||||||
|  | func (c *GCMCipher) EncryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // EncryptInto returns an error since the encrypted data needs to be a []byte
 | ||||||
|  | func (c *GCMCipher) DecryptInto(s *string) error { | ||||||
|  | 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // codecFunc is a function that takes a string and encodes/decodes it
 | // codecFunc is a function that takes a string and encodes/decodes it
 | ||||||
| type codecFunc func([]byte) ([]byte, error) | type codecFunc func([]byte) ([]byte, error) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
|  | 	mathrand "math/rand" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
|  | @ -117,6 +118,80 @@ func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) { | ||||||
| 	assert.NotEqual(t, encrypted, decrypted) | 	assert.NotEqual(t, encrypted, decrypted) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestEncryptIntoAndDecryptInto(t *testing.T) { | ||||||
|  | 	// Test our 2 cipher types
 | ||||||
|  | 	cipherInits := map[string]func([]byte) (Cipher, error){ | ||||||
|  | 		"CFB": NewCFBCipher, | ||||||
|  | 		"GCM": NewGCMCipher, | ||||||
|  | 	} | ||||||
|  | 	for name, initCipher := range cipherInits { | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			// Test all 3 valid AES sizes
 | ||||||
|  | 			for _, secretSize := range []int{16, 24, 32} { | ||||||
|  | 				t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { | ||||||
|  | 					secret := make([]byte, secretSize) | ||||||
|  | 					_, err := io.ReadFull(rand.Reader, secret) | ||||||
|  | 					assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 					// Test Standard & Base64 wrapped
 | ||||||
|  | 					cstd, err := initCipher(secret) | ||||||
|  | 					assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 					cb64, err := NewBase64Cipher(initCipher, secret) | ||||||
|  | 					assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 					ciphers := map[string]Cipher{ | ||||||
|  | 						"Standard": cstd, | ||||||
|  | 						"Base64":   cb64, | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					for cName, c := range ciphers { | ||||||
|  | 						// Check no errors with empty or nil strings
 | ||||||
|  | 						if cName == "Base64" { | ||||||
|  | 							empty := "" | ||||||
|  | 							assert.Equal(t, nil, c.EncryptInto(&empty)) | ||||||
|  | 							assert.Equal(t, nil, c.DecryptInto(&empty)) | ||||||
|  | 							assert.Equal(t, nil, c.EncryptInto(nil)) | ||||||
|  | 							assert.Equal(t, nil, c.DecryptInto(nil)) | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						t.Run(cName, func(t *testing.T) { | ||||||
|  | 							// Test various sizes sessions might be
 | ||||||
|  | 							for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||||
|  | 								t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { | ||||||
|  | 									runEncryptIntoAndDecryptInto(t, c, cName, dataSize) | ||||||
|  | 								}) | ||||||
|  | 							} | ||||||
|  | 						}) | ||||||
|  | 					} | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func runEncryptIntoAndDecryptInto(t *testing.T, c Cipher, cipherType string, dataSize int) { | ||||||
|  | 	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||||
|  | 	b := make([]byte, dataSize) | ||||||
|  | 	for i := range b { | ||||||
|  | 		b[i] = charset[mathrand.Intn(len(charset))] | ||||||
|  | 	} | ||||||
|  | 	data := string(b) | ||||||
|  | 	originalData := data | ||||||
|  | 
 | ||||||
|  | 	// Base64 is the only cipher that supports string->string Encrypt/Decrypt Into methods
 | ||||||
|  | 	if cipherType == "Base64" { | ||||||
|  | 		assert.Equal(t, nil, c.EncryptInto(&data)) | ||||||
|  | 		assert.NotEqual(t, originalData, data) | ||||||
|  | 
 | ||||||
|  | 		assert.Equal(t, nil, c.DecryptInto(&data)) | ||||||
|  | 		assert.Equal(t, originalData, data) | ||||||
|  | 	} else { | ||||||
|  | 		assert.NotEqual(t, nil, c.EncryptInto(&data)) | ||||||
|  | 		assert.NotEqual(t, nil, c.DecryptInto(&data)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestDecryptCFBWrongSecret(t *testing.T) { | func TestDecryptCFBWrongSecret(t *testing.T) { | ||||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||||
|  | @ -228,25 +303,3 @@ func TestCFBtoGCMErrors(t *testing.T) { | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) { |  | ||||||
| 	const secret = "0123456789abcdefghijklmnopqrstuv" |  | ||||||
| 	c, err := NewCipher([]byte(secret)) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 
 |  | ||||||
| 	token := "my access token" |  | ||||||
| 	originalToken := token |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, nil, c.EncryptInto(&token)) |  | ||||||
| 	assert.NotEqual(t, originalToken, token) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, nil, c.DecryptInto(&token)) |  | ||||||
| 	assert.Equal(t, originalToken, token) |  | ||||||
| 
 |  | ||||||
| 	// Check no errors with empty or nil strings
 |  | ||||||
| 	empty := "" |  | ||||||
| 	assert.Equal(t, nil, c.EncryptInto(&empty)) |  | ||||||
| 	assert.Equal(t, nil, c.DecryptInto(&empty)) |  | ||||||
| 	assert.Equal(t, nil, c.EncryptInto(nil)) |  | ||||||
| 	assert.Equal(t, nil, c.DecryptInto(nil)) |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -104,4 +104,3 @@ func checkHmac(input, expected string) bool { | ||||||
| 	} | 	} | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue