Move Encrypt/Decrypt Into helper to session_state.go
This helper method is only applicable for Base64 wrapped encryption since it operated on string -> string primarily. It wouldn't be used for pure CFB/GCM ciphers. After a messagePack session refactor, this method would further only be used for legacy session compatibility - making its placement in cipher.go not ideal.
This commit is contained in:
		
							parent
							
								
									014fa682be
								
							
						
					
					
						commit
						1979627534
					
				|  | @ -77,7 +77,7 @@ func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) { | ||||||
| 			&ss.IDToken, | 			&ss.IDToken, | ||||||
| 			&ss.RefreshToken, | 			&ss.RefreshToken, | ||||||
| 		} { | 		} { | ||||||
| 			err := c.EncryptInto(s) | 			err := into(s, c.Encrypt) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return "", err | 				return "", err | ||||||
| 			} | 			} | ||||||
|  | @ -106,13 +106,13 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||||
| 	} else { | 	} else { | ||||||
| 		// Backward compatibility with using unencrypted Email or User
 | 		// Backward compatibility with using unencrypted Email or User
 | ||||||
| 		// Decryption errors will leave original string
 | 		// Decryption errors will leave original string
 | ||||||
| 		err = c.DecryptInto(&ss.Email) | 		err = into(&ss.Email, c.Decrypt) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			if !utf8.ValidString(ss.Email) { | 			if !utf8.ValidString(ss.Email) { | ||||||
| 				return nil, errors.New("invalid value for decrypted email") | 				return nil, errors.New("invalid value for decrypted email") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		err = c.DecryptInto(&ss.User) | 		err = into(&ss.User, c.Decrypt) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			if !utf8.ValidString(ss.User) { | 			if !utf8.ValidString(ss.User) { | ||||||
| 				return nil, errors.New("invalid value for decrypted user") | 				return nil, errors.New("invalid value for decrypted user") | ||||||
|  | @ -125,7 +125,7 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||||
| 			&ss.IDToken, | 			&ss.IDToken, | ||||||
| 			&ss.RefreshToken, | 			&ss.RefreshToken, | ||||||
| 		} { | 		} { | ||||||
| 			err := c.DecryptInto(s) | 			err := into(s, c.Decrypt) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
|  | @ -133,3 +133,20 @@ func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { | ||||||
| 	} | 	} | ||||||
| 	return &ss, nil | 	return &ss, nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // codecFunc is a function that takes a []byte and encodes/decodes it
 | ||||||
|  | type codecFunc func([]byte) ([]byte, error) | ||||||
|  | 
 | ||||||
|  | func into(s *string, f codecFunc) error { | ||||||
|  | 	// Do not encrypt/decrypt nil or empty strings
 | ||||||
|  | 	if s == nil || *s == "" { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	d, err := f([]byte(*s)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	*s = string(d) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -1,11 +1,13 @@ | ||||||
| package sessions_test | package sessions | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"crypto/rand" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	mathrand "math/rand" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  | @ -26,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := newTestCipher([]byte(altSecret)) | 	c2, err := newTestCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &SessionState{ | ||||||
| 		Email:             "user@domain.com", | 		Email:             "user@domain.com", | ||||||
| 		PreferredUsername: "user", | 		PreferredUsername: "user", | ||||||
| 		AccessToken:       "token1234", | 		AccessToken:       "token1234", | ||||||
|  | @ -38,7 +40,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	encoded, err := s.EncodeSessionState(c) | 	encoded, err := s.EncodeSessionState(c) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	ss, err := sessions.DecodeSessionState(encoded, c) | 	ss, err := DecodeSessionState(encoded, c) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "", ss.User) | 	assert.Equal(t, "", ss.User) | ||||||
|  | @ -51,7 +53,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||||
| 
 | 
 | ||||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | ||||||
| 	ss, err = sessions.DecodeSessionState(encoded, c2) | 	ss, err = DecodeSessionState(encoded, c2) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| } | } | ||||||
|  | @ -61,7 +63,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := newTestCipher([]byte(altSecret)) | 	c2, err := newTestCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &sessions.SessionState{ | 	s := &SessionState{ | ||||||
| 		User:              "just-user", | 		User:              "just-user", | ||||||
| 		PreferredUsername: "ju", | 		PreferredUsername: "ju", | ||||||
| 		Email:             "user@domain.com", | 		Email:             "user@domain.com", | ||||||
|  | @ -73,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	encoded, err := s.EncodeSessionState(c) | 	encoded, err := s.EncodeSessionState(c) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	ss, err := sessions.DecodeSessionState(encoded, c) | 	ss, err := DecodeSessionState(encoded, c) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, s.User, ss.User) | 	assert.Equal(t, s.User, ss.User) | ||||||
|  | @ -85,13 +87,13 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||||
| 
 | 
 | ||||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | ||||||
| 	ss, err = sessions.DecodeSessionState(encoded, c2) | 	ss, err = DecodeSessionState(encoded, c2) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerializationNoCipher(t *testing.T) { | func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||||
| 	s := &sessions.SessionState{ | 	s := &SessionState{ | ||||||
| 		Email:             "user@domain.com", | 		Email:             "user@domain.com", | ||||||
| 		PreferredUsername: "user", | 		PreferredUsername: "user", | ||||||
| 		AccessToken:       "token1234", | 		AccessToken:       "token1234", | ||||||
|  | @ -103,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	// only email should have been serialized
 | 	// only email should have been serialized
 | ||||||
| 	ss, err := sessions.DecodeSessionState(encoded, nil) | 	ss, err := DecodeSessionState(encoded, nil) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "", ss.User) | 	assert.Equal(t, "", ss.User) | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
|  | @ -113,7 +115,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||||
| 	s := &sessions.SessionState{ | 	s := &SessionState{ | ||||||
| 		User:              "just-user", | 		User:              "just-user", | ||||||
| 		Email:             "user@domain.com", | 		Email:             "user@domain.com", | ||||||
| 		PreferredUsername: "user", | 		PreferredUsername: "user", | ||||||
|  | @ -126,7 +128,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	// only email should have been serialized
 | 	// only email should have been serialized
 | ||||||
| 	ss, err := sessions.DecodeSessionState(encoded, nil) | 	ss, err := DecodeSessionState(encoded, nil) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, s.User, ss.User) | 	assert.Equal(t, s.User, ss.User) | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
|  | @ -136,18 +138,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestExpired(t *testing.T) { | func TestExpired(t *testing.T) { | ||||||
| 	s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} | 	s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} | ||||||
| 	assert.Equal(t, true, s.IsExpired()) | 	assert.Equal(t, true, s.IsExpired()) | ||||||
| 
 | 
 | ||||||
| 	s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} | 	s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} | ||||||
| 	assert.Equal(t, false, s.IsExpired()) | 	assert.Equal(t, false, s.IsExpired()) | ||||||
| 
 | 
 | ||||||
| 	s = &sessions.SessionState{} | 	s = &SessionState{} | ||||||
| 	assert.Equal(t, false, s.IsExpired()) | 	assert.Equal(t, false, s.IsExpired()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type testCase struct { | type testCase struct { | ||||||
| 	sessions.SessionState | 	SessionState | ||||||
| 	Encoded string | 	Encoded string | ||||||
| 	Cipher  encryption.Cipher | 	Cipher  encryption.Cipher | ||||||
| 	Error   bool | 	Error   bool | ||||||
|  | @ -163,14 +165,14 @@ func TestEncodeSessionState(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email:        "user@domain.com", | 				Email:        "user@domain.com", | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
|  | @ -185,7 +187,7 @@ func TestEncodeSessionState(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	for i, tc := range testCases { | 	for i, tc := range testCases { | ||||||
| 		encoded, err := tc.EncodeSessionState(tc.Cipher) | 		encoded, err := tc.EncodeSessionState(tc.Cipher) | ||||||
| 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | 		t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | ||||||
| 		if tc.Error { | 		if tc.Error { | ||||||
| 			assert.Error(t, err) | 			assert.Error(t, err) | ||||||
| 			assert.Empty(t, encoded) | 			assert.Empty(t, encoded) | ||||||
|  | @ -210,34 +212,34 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "", | 				User:  "", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"Email":"user@domain.com"}`, | 			Encoded: `{"Email":"user@domain.com"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				User: "just-user", | 				User: "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"User":"just-user"}`, | 			Encoded: `{"User":"just-user"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), | 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email:        "user@domain.com", | 				Email:        "user@domain.com", | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
|  | @ -250,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 			Cipher:  c, | 			Cipher:  c, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
|  | @ -268,7 +270,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 			Error:   true, | 			Error:   true, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
 | 				User:  "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
 | ||||||
| 			}, | 			}, | ||||||
|  | @ -278,8 +280,8 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for i, tc := range testCases { | 	for i, tc := range testCases { | ||||||
| 		ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) | 		ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) | ||||||
| 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | 		t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | ||||||
| 		if tc.Error { | 		if tc.Error { | ||||||
| 			assert.Error(t, err) | 			assert.Error(t, err) | ||||||
| 			assert.Nil(t, ss) | 			assert.Nil(t, ss) | ||||||
|  | @ -301,7 +303,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateAge(t *testing.T) { | func TestSessionStateAge(t *testing.T) { | ||||||
| 	ss := &sessions.SessionState{} | 	ss := &SessionState{} | ||||||
| 
 | 
 | ||||||
| 	// Created at unset so should be 0
 | 	// Created at unset so should be 0
 | ||||||
| 	assert.Equal(t, time.Duration(0), ss.Age()) | 	assert.Equal(t, time.Duration(0), ss.Age()) | ||||||
|  | @ -310,3 +312,44 @@ func TestSessionStateAge(t *testing.T) { | ||||||
| 	ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour)) | 	ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour)) | ||||||
| 	assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) | 	assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestIntoEncryptAndIntoDecrypt(t *testing.T) { | ||||||
|  | 	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||||
|  | 
 | ||||||
|  | 	// Test all 3 valid AES sizes
 | ||||||
|  | 	for _, secretSize := range []int{16, 24, 32} { | ||||||
|  | 		t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { | ||||||
|  | 			secret := make([]byte, secretSize) | ||||||
|  | 			_, err := io.ReadFull(rand.Reader, secret) | ||||||
|  | 			assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 			c, err := newTestCipher(secret) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 			// Check no errors with empty or nil strings
 | ||||||
|  | 			empty := "" | ||||||
|  | 			assert.Equal(t, nil, into(&empty, c.Encrypt)) | ||||||
|  | 			assert.Equal(t, nil, into(&empty, c.Decrypt)) | ||||||
|  | 			assert.Equal(t, nil, into(nil, c.Encrypt)) | ||||||
|  | 			assert.Equal(t, nil, into(nil, c.Decrypt)) | ||||||
|  | 
 | ||||||
|  | 			// Test various sizes tokens might be
 | ||||||
|  | 			for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { | ||||||
|  | 				t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { | ||||||
|  | 					b := make([]byte, dataSize) | ||||||
|  | 					for i := range b { | ||||||
|  | 						b[i] = charset[mathrand.Intn(len(charset))] | ||||||
|  | 					} | ||||||
|  | 					data := string(b) | ||||||
|  | 					originalData := data | ||||||
|  | 
 | ||||||
|  | 					assert.Equal(t, nil, into(&data, c.Encrypt)) | ||||||
|  | 					assert.NotEqual(t, originalData, data) | ||||||
|  | 
 | ||||||
|  | 					assert.Equal(t, nil, into(&data, c.Decrypt)) | ||||||
|  | 					assert.Equal(t, originalData, data) | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -13,11 +13,9 @@ import ( | ||||||
| type Cipher interface { | type Cipher interface { | ||||||
| 	Encrypt(value []byte) ([]byte, error) | 	Encrypt(value []byte) ([]byte, error) | ||||||
| 	Decrypt(ciphertext []byte) ([]byte, error) | 	Decrypt(ciphertext []byte) ([]byte, error) | ||||||
| 	EncryptInto(s *string) error |  | ||||||
| 	DecryptInto(s *string) error |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type Base64Cipher struct { | type base64Cipher struct { | ||||||
| 	Cipher Cipher | 	Cipher Cipher | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -28,11 +26,11 @@ func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Ci | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return &Base64Cipher{Cipher: c}, nil | 	return &base64Cipher{Cipher: c}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
 | // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
 | ||||||
| func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| 	encrypted, err := c.Cipher.Encrypt(value) | 	encrypted, err := c.Cipher.Encrypt(value) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -42,7 +40,7 @@ func (c *Base64Cipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Decrypt Base64 decodes a value & decrypts it with the embedded Cipher
 | // Decrypt Base64 decodes a value & decrypts it with the embedded Cipher
 | ||||||
| func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) | 	encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to base64 decode value %s", err) | 		return nil, fmt.Errorf("failed to base64 decode value %s", err) | ||||||
|  | @ -51,17 +49,7 @@ func (c *Base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	return c.Cipher.Decrypt(encrypted) | 	return c.Cipher.Decrypt(encrypted) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // EncryptInto encrypts the value and stores it back in the string pointer
 | type cfbCipher struct { | ||||||
| func (c *Base64Cipher) EncryptInto(s *string) error { |  | ||||||
| 	return into(c.Encrypt, s) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // DecryptInto decrypts the value and stores it back in the string pointer
 |  | ||||||
| func (c *Base64Cipher) DecryptInto(s *string) error { |  | ||||||
| 	return into(c.Decrypt, s) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type CFBCipher struct { |  | ||||||
| 	cipher.Block | 	cipher.Block | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -71,11 +59,11 @@ func NewCFBCipher(secret []byte) (Cipher, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return &CFBCipher{Block: c}, err | 	return &cfbCipher{Block: c}, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Encrypt with AES CFB
 | // Encrypt with AES CFB
 | ||||||
| func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { | func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| 	ciphertext := make([]byte, aes.BlockSize+len(value)) | 	ciphertext := make([]byte, aes.BlockSize+len(value)) | ||||||
| 	iv := ciphertext[:aes.BlockSize] | 	iv := ciphertext[:aes.BlockSize] | ||||||
| 	if _, err := io.ReadFull(rand.Reader, iv); err != nil { | 	if _, err := io.ReadFull(rand.Reader, iv); err != nil { | ||||||
|  | @ -88,7 +76,7 @@ func (c *CFBCipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Decrypt an AES CFB ciphertext
 | // Decrypt an AES CFB ciphertext
 | ||||||
| func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	if len(ciphertext) < aes.BlockSize { | 	if len(ciphertext) < aes.BlockSize { | ||||||
| 		return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext)) | 		return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext)) | ||||||
| 	} | 	} | ||||||
|  | @ -101,17 +89,7 @@ func (c *CFBCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	return plaintext, nil | 	return plaintext, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able
 | type gcmCipher struct { | ||||||
| func (c *CFBCipher) EncryptInto(s *string) error { |  | ||||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // EncryptInto returns an error since the encrypted data needs to be a []byte
 |  | ||||||
| func (c *CFBCipher) DecryptInto(s *string) error { |  | ||||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type GCMCipher struct { |  | ||||||
| 	cipher.Block | 	cipher.Block | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -121,11 +99,11 @@ func NewGCMCipher(secret []byte) (Cipher, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return &GCMCipher{Block: c}, err | 	return &gcmCipher{Block: c}, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Encrypt with AES GCM on raw bytes
 | // Encrypt with AES GCM on raw bytes
 | ||||||
| func (c *GCMCipher) Encrypt(value []byte) ([]byte, error) { | func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| 	gcm, err := cipher.NewGCM(c.Block) | 	gcm, err := cipher.NewGCM(c.Block) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -141,7 +119,7 @@ func (c *GCMCipher) Encrypt(value []byte) ([]byte, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Decrypt an AES GCM ciphertext
 | // Decrypt an AES GCM ciphertext
 | ||||||
| func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { | func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	gcm, err := cipher.NewGCM(c.Block) | 	gcm, err := cipher.NewGCM(c.Block) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -156,30 +134,3 @@ func (c *GCMCipher) Decrypt(ciphertext []byte) ([]byte, error) { | ||||||
| 	} | 	} | ||||||
| 	return plaintext, nil | 	return plaintext, nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // EncryptInto returns an error since the encrypted data is a []byte that isn't string cast-able
 |  | ||||||
| func (c *GCMCipher) EncryptInto(s *string) error { |  | ||||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // EncryptInto returns an error since the encrypted data needs to be a []byte
 |  | ||||||
| func (c *GCMCipher) DecryptInto(s *string) error { |  | ||||||
| 	return fmt.Errorf("CFBCipher is not a string->string compatible cipher") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // codecFunc is a function that takes a string and encodes/decodes it
 |  | ||||||
| type codecFunc func([]byte) ([]byte, error) |  | ||||||
| 
 |  | ||||||
| func into(f codecFunc, s *string) error { |  | ||||||
| 	// Do not encrypt/decrypt nil or empty strings
 |  | ||||||
| 	if s == nil || *s == "" { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	d, err := f([]byte(*s)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	*s = string(d) |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -5,7 +5,6 @@ import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	mathrand "math/rand" |  | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
|  | @ -118,80 +117,6 @@ func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) { | ||||||
| 	assert.NotEqual(t, encrypted, decrypted) | 	assert.NotEqual(t, encrypted, decrypted) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestEncryptIntoAndDecryptInto(t *testing.T) { |  | ||||||
| 	// Test our 2 cipher types
 |  | ||||||
| 	cipherInits := map[string]func([]byte) (Cipher, error){ |  | ||||||
| 		"CFB": NewCFBCipher, |  | ||||||
| 		"GCM": NewGCMCipher, |  | ||||||
| 	} |  | ||||||
| 	for name, initCipher := range cipherInits { |  | ||||||
| 		t.Run(name, func(t *testing.T) { |  | ||||||
| 			// Test all 3 valid AES sizes
 |  | ||||||
| 			for _, secretSize := range []int{16, 24, 32} { |  | ||||||
| 				t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { |  | ||||||
| 					secret := make([]byte, secretSize) |  | ||||||
| 					_, err := io.ReadFull(rand.Reader, secret) |  | ||||||
| 					assert.Equal(t, nil, err) |  | ||||||
| 
 |  | ||||||
| 					// Test Standard & Base64 wrapped
 |  | ||||||
| 					cstd, err := initCipher(secret) |  | ||||||
| 					assert.Equal(t, nil, err) |  | ||||||
| 
 |  | ||||||
| 					cb64, err := NewBase64Cipher(initCipher, secret) |  | ||||||
| 					assert.Equal(t, nil, err) |  | ||||||
| 
 |  | ||||||
| 					ciphers := map[string]Cipher{ |  | ||||||
| 						"Standard": cstd, |  | ||||||
| 						"Base64":   cb64, |  | ||||||
| 					} |  | ||||||
| 
 |  | ||||||
| 					for cName, c := range ciphers { |  | ||||||
| 						// Check no errors with empty or nil strings
 |  | ||||||
| 						if cName == "Base64" { |  | ||||||
| 							empty := "" |  | ||||||
| 							assert.Equal(t, nil, c.EncryptInto(&empty)) |  | ||||||
| 							assert.Equal(t, nil, c.DecryptInto(&empty)) |  | ||||||
| 							assert.Equal(t, nil, c.EncryptInto(nil)) |  | ||||||
| 							assert.Equal(t, nil, c.DecryptInto(nil)) |  | ||||||
| 						} |  | ||||||
| 
 |  | ||||||
| 						t.Run(cName, func(t *testing.T) { |  | ||||||
| 							// Test various sizes sessions might be
 |  | ||||||
| 							for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { |  | ||||||
| 								t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { |  | ||||||
| 									runEncryptIntoAndDecryptInto(t, c, cName, dataSize) |  | ||||||
| 								}) |  | ||||||
| 							} |  | ||||||
| 						}) |  | ||||||
| 					} |  | ||||||
| 				}) |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func runEncryptIntoAndDecryptInto(t *testing.T, c Cipher, cipherType string, dataSize int) { |  | ||||||
| 	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" |  | ||||||
| 	b := make([]byte, dataSize) |  | ||||||
| 	for i := range b { |  | ||||||
| 		b[i] = charset[mathrand.Intn(len(charset))] |  | ||||||
| 	} |  | ||||||
| 	data := string(b) |  | ||||||
| 	originalData := data |  | ||||||
| 
 |  | ||||||
| 	// Base64 is the only cipher that supports string->string Encrypt/Decrypt Into methods
 |  | ||||||
| 	if cipherType == "Base64" { |  | ||||||
| 		assert.Equal(t, nil, c.EncryptInto(&data)) |  | ||||||
| 		assert.NotEqual(t, originalData, data) |  | ||||||
| 
 |  | ||||||
| 		assert.Equal(t, nil, c.DecryptInto(&data)) |  | ||||||
| 		assert.Equal(t, originalData, data) |  | ||||||
| 	} else { |  | ||||||
| 		assert.NotEqual(t, nil, c.EncryptInto(&data)) |  | ||||||
| 		assert.NotEqual(t, nil, c.DecryptInto(&data)) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestDecryptCFBWrongSecret(t *testing.T) { | func TestDecryptCFBWrongSecret(t *testing.T) { | ||||||
| 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | 	secret1 := []byte("0123456789abcdefghijklmnopqrstuv") | ||||||
| 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | 	secret2 := []byte("9876543210abcdefghijklmnopqrstuv") | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue