diff --git a/CHANGELOG.md b/CHANGELOG.md index f35a3b43..f3cca2b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,7 @@ ## Changes since v5.1.1 +- [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves) - [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed) - [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed) - [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer) diff --git a/pkg/apis/options/sessions.go b/pkg/apis/options/sessions.go index b490baf7..3b2a4d19 100644 --- a/pkg/apis/options/sessions.go +++ b/pkg/apis/options/sessions.go @@ -4,9 +4,9 @@ import "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" // SessionOptions contains configuration options for the SessionStore providers. type SessionOptions struct { - Type string `flag:"session-store-type" cfg:"session_store_type"` - Cipher *encryption.Cipher `cfg:",internal"` - Redis RedisStoreOptions `cfg:",squash"` + Type string `flag:"session-store-type" cfg:"session_store_type"` + Cipher encryption.Cipher `cfg:",internal"` + Redis RedisStoreOptions `cfg:",squash"` } // CookieSessionStoreType is used to indicate the CookieSessionStore should be diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index f2e6633e..44b91bd2 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -60,7 +60,7 @@ func (s *SessionState) String() string { } // EncodeSessionState returns string representation of the current session -func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) { +func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) { var ss SessionState if c == nil { // Store only Email and User when cipher is unavailable @@ -77,7 +77,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) &ss.IDToken, &ss.RefreshToken, } { - err := c.EncryptInto(s) + err := into(s, c.Encrypt) if err != nil { return "", err } @@ -89,7 +89,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) } // DecodeSessionState decodes the session cookie string into a SessionState -func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { +func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { var ss SessionState err := json.Unmarshal([]byte(v), &ss) if err != nil { @@ -104,24 +104,18 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { PreferredUsername: ss.PreferredUsername, } } else { - // Backward compatibility with using unencrypted Email - if ss.Email != "" { - decryptedEmail, errEmail := c.Decrypt(ss.Email) - if errEmail == nil { - if !utf8.ValidString(decryptedEmail) { - return nil, errors.New("invalid value for decrypted email") - } - ss.Email = decryptedEmail + // Backward compatibility with using unencrypted Email or User + // Decryption errors will leave original string + err = into(&ss.Email, c.Decrypt) + if err == nil { + if !utf8.ValidString(ss.Email) { + return nil, errors.New("invalid value for decrypted email") } } - // Backward compatibility with using unencrypted User - if ss.User != "" { - decryptedUser, errUser := c.Decrypt(ss.User) - if errUser == nil { - if !utf8.ValidString(decryptedUser) { - return nil, errors.New("invalid value for decrypted user") - } - ss.User = decryptedUser + err = into(&ss.User, c.Decrypt) + if err == nil { + if !utf8.ValidString(ss.User) { + return nil, errors.New("invalid value for decrypted user") } } @@ -131,7 +125,7 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { &ss.IDToken, &ss.RefreshToken, } { - err := c.DecryptInto(s) + err := into(s, c.Decrypt) if err != nil { return nil, err } @@ -139,3 +133,20 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { } return &ss, nil } + +// codecFunc is a function that takes a []byte and encodes/decodes it +type codecFunc func([]byte) ([]byte, error) + +func into(s *string, f codecFunc) error { + // Do not encrypt/decrypt nil or empty strings + if s == nil || *s == "" { + return nil + } + + d, err := f([]byte(*s)) + if err != nil { + return err + } + *s = string(d) + return nil +} diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 150e9c9d..3e9554c5 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -1,11 +1,13 @@ -package sessions_test +package sessions import ( + "crypto/rand" "fmt" + "io" + mathrand "math/rand" "testing" "time" - "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/stretchr/testify/assert" ) @@ -17,12 +19,16 @@ func timePtr(t time.Time) *time.Time { return &t } +func newTestCipher(secret []byte) (encryption.Cipher, error) { + return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) +} + func TestSessionStateSerialization(t *testing.T) { - c, err := encryption.NewCipher([]byte(secret)) + c, err := newTestCipher([]byte(secret)) assert.Equal(t, nil, err) - c2, err := encryption.NewCipher([]byte(altSecret)) + c2, err := newTestCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &sessions.SessionState{ + s := &SessionState{ Email: "user@domain.com", PreferredUsername: "user", AccessToken: "token1234", @@ -34,7 +40,7 @@ func TestSessionStateSerialization(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := sessions.DecodeSessionState(encoded, c) + ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, "", ss.User) @@ -47,17 +53,17 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = sessions.DecodeSessionState(encoded, c2) + ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.NotEqual(t, nil, err) } func TestSessionStateSerializationWithUser(t *testing.T) { - c, err := encryption.NewCipher([]byte(secret)) + c, err := newTestCipher([]byte(secret)) assert.Equal(t, nil, err) - c2, err := encryption.NewCipher([]byte(altSecret)) + c2, err := newTestCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &sessions.SessionState{ + s := &SessionState{ User: "just-user", PreferredUsername: "ju", Email: "user@domain.com", @@ -69,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := sessions.DecodeSessionState(encoded, c) + ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) @@ -81,13 +87,13 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = sessions.DecodeSessionState(encoded, c2) + ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.NotEqual(t, nil, err) } func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &sessions.SessionState{ + s := &SessionState{ Email: "user@domain.com", PreferredUsername: "user", AccessToken: "token1234", @@ -99,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := sessions.DecodeSessionState(encoded, nil) + ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, "", ss.User) assert.Equal(t, s.Email, ss.Email) @@ -109,7 +115,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { - s := &sessions.SessionState{ + s := &SessionState{ User: "just-user", Email: "user@domain.com", PreferredUsername: "user", @@ -122,7 +128,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := sessions.DecodeSessionState(encoded, nil) + ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) @@ -132,20 +138,20 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { } func TestExpired(t *testing.T) { - s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} + s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} assert.Equal(t, true, s.IsExpired()) - s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} + s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} assert.Equal(t, false, s.IsExpired()) - s = &sessions.SessionState{} + s = &SessionState{} assert.Equal(t, false, s.IsExpired()) } type testCase struct { - sessions.SessionState + SessionState Encoded string - Cipher *encryption.Cipher + Cipher encryption.Cipher Error bool } @@ -159,14 +165,14 @@ func TestEncodeSessionState(t *testing.T) { testCases := []testCase{ { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -181,7 +187,7 @@ func TestEncodeSessionState(t *testing.T) { for i, tc := range testCases { encoded, err := tc.EncodeSessionState(tc.Cipher) - t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) + t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) if tc.Error { assert.Error(t, err) assert.Empty(t, encoded) @@ -201,39 +207,39 @@ func TestDecodeSessionState(t *testing.T) { eJSON, _ := e.MarshalJSON() eString := string(eJSON) - c, err := encryption.NewCipher([]byte(secret)) + c, err := newTestCipher([]byte(secret)) assert.NoError(t, err) testCases := []testCase{ { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "", }, Encoded: `{"Email":"user@domain.com"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ User: "just-user", }, Encoded: `{"User":"just-user"}`, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -246,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { Cipher: c, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "just-user", }, @@ -264,7 +270,7 @@ func TestDecodeSessionState(t *testing.T) { Error: true, }, { - SessionState: sessions.SessionState{ + SessionState: SessionState{ Email: "user@domain.com", User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user }, @@ -274,8 +280,8 @@ func TestDecodeSessionState(t *testing.T) { } for i, tc := range testCases { - ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) - t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) + ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) + t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err) if tc.Error { assert.Error(t, err) assert.Nil(t, ss) @@ -297,7 +303,7 @@ func TestDecodeSessionState(t *testing.T) { } func TestSessionStateAge(t *testing.T) { - ss := &sessions.SessionState{} + ss := &SessionState{} // Created at unset so should be 0 assert.Equal(t, time.Duration(0), ss.Age()) @@ -306,3 +312,44 @@ func TestSessionStateAge(t *testing.T) { ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour)) assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) } + +func TestIntoEncryptAndIntoDecrypt(t *testing.T) { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + // Test all 3 valid AES sizes + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + c, err := newTestCipher(secret) + assert.NoError(t, err) + + // Check no errors with empty or nil strings + empty := "" + assert.Equal(t, nil, into(&empty, c.Encrypt)) + assert.Equal(t, nil, into(&empty, c.Decrypt)) + assert.Equal(t, nil, into(nil, c.Encrypt)) + assert.Equal(t, nil, into(nil, c.Decrypt)) + + // Test various sizes tokens might be + for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { + t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { + b := make([]byte, dataSize) + for i := range b { + b[i] = charset[mathrand.Intn(len(charset))] + } + data := string(b) + originalData := data + + assert.Equal(t, nil, into(&data, c.Encrypt)) + assert.NotEqual(t, originalData, data) + + assert.Equal(t, nil, into(&data, c.Decrypt)) + assert.Equal(t, originalData, data) + }) + } + }) + } +} diff --git a/pkg/encryption/cipher.go b/pkg/encryption/cipher.go index 4eb42b03..c1158b5c 100644 --- a/pkg/encryption/cipher.go +++ b/pkg/encryption/cipher.go @@ -3,183 +3,134 @@ package encryption import ( "crypto/aes" "crypto/cipher" - "crypto/hmac" "crypto/rand" - "crypto/sha1" - "crypto/sha256" "encoding/base64" "fmt" - "hash" "io" - "net/http" - "strconv" - "strings" - "time" ) -// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary -func SecretBytes(secret string) []byte { - b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "=")) - if err == nil { - // Only return decoded form if a valid AES length - // Don't want unintentional decoding resulting in invalid lengths confusing a user - // that thought they used a 16, 24, 32 length string - for _, i := range []int{16, 24, 32} { - if len(b) == i { - return b - } - } - } - // If decoding didn't work or resulted in non-AES compliant length, - // assume the raw string was the intended secret - return []byte(secret) +// Cipher provides methods to encrypt and decrypt +type Cipher interface { + Encrypt(value []byte) ([]byte, error) + Decrypt(ciphertext []byte) ([]byte, error) } -// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set. -// additionally, the 'value' is encrypted so it's opaque to the browser - -// Validate ensures a cookie is properly signed -func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { - // value, timestamp, sig - parts := strings.Split(cookie.Value, "|") - if len(parts) != 3 { - return - } - if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) { - ts, err := strconv.Atoi(parts[1]) - if err != nil { - return - } - // The expiration timestamp set when the cookie was created - // isn't sent back by the browser. Hence, we check whether the - // creation timestamp stored in the cookie falls within the - // window defined by (Now()-expiration, Now()]. - t = time.Unix(int64(ts), 0) - if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { - // it's a valid cookie. now get the contents - rawValue, err := base64.URLEncoding.DecodeString(parts[0]) - if err == nil { - value = string(rawValue) - ok = true - return - } - } - } - return +type base64Cipher struct { + Cipher Cipher } -// SignedValue returns a cookie that is signed and can later be checked with Validate -func SignedValue(seed string, key string, value string, now time.Time) string { - encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) - timeStr := fmt.Sprintf("%d", now.Unix()) - sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr) - cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) - return cookieVal -} - -func cookieSignature(signer func() hash.Hash, args ...string) string { - h := hmac.New(signer, []byte(args[0])) - for _, arg := range args[1:] { - h.Write([]byte(arg)) +// NewBase64Cipher returns a new AES Cipher for encrypting cookie values +// and wrapping them in Base64 -- Supports Legacy encryption scheme +func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Cipher, error) { + c, err := initCipher(secret) + if err != nil { + return nil, err } - var b []byte - b = h.Sum(b) - return base64.URLEncoding.EncodeToString(b) + return &base64Cipher{Cipher: c}, nil } -func checkSignature(signature string, args ...string) bool { - checkSig := cookieSignature(sha256.New, args...) - if checkHmac(signature, checkSig) { - return true +// Encrypt encrypts a value with the embedded Cipher & Base64 encodes it +func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) { + encrypted, err := c.Cipher.Encrypt(value) + if err != nil { + return nil, err } - // TODO: After appropriate rollout window, remove support for SHA1 - legacySig := cookieSignature(sha1.New, args...) - return checkHmac(signature, legacySig) + return []byte(base64.StdEncoding.EncodeToString(encrypted)), nil } -func checkHmac(input, expected string) bool { - inputMAC, err1 := base64.URLEncoding.DecodeString(input) - if err1 == nil { - expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) - if err2 == nil { - return hmac.Equal(inputMAC, expectedMAC) - } +// Decrypt Base64 decodes a value & decrypts it with the embedded Cipher +func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) { + encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext)) + if err != nil { + return nil, fmt.Errorf("failed to base64 decode value %s", err) } - return false + + return c.Cipher.Decrypt(encrypted) } -// Cipher provides methods to encrypt and decrypt cookie values -type Cipher struct { +type cfbCipher struct { cipher.Block } -// NewCipher returns a new aes Cipher for encrypting cookie values -func NewCipher(secret []byte) (*Cipher, error) { +// NewCFBCipher returns a new AES CFB Cipher +func NewCFBCipher(secret []byte) (Cipher, error) { c, err := aes.NewCipher(secret) if err != nil { return nil, err } - return &Cipher{Block: c}, err + return &cfbCipher{Block: c}, err } -// Encrypt a value for use in a cookie -func (c *Cipher) Encrypt(value string) (string, error) { +// Encrypt with AES CFB +func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) { ciphertext := make([]byte, aes.BlockSize+len(value)) iv := ciphertext[:aes.BlockSize] if _, err := io.ReadFull(rand.Reader, iv); err != nil { - return "", fmt.Errorf("failed to create initialization vector %s", err) + return nil, fmt.Errorf("failed to create initialization vector %s", err) } stream := cipher.NewCFBEncrypter(c.Block, iv) - stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value)) - return base64.StdEncoding.EncodeToString(ciphertext), nil + stream.XORKeyStream(ciphertext[aes.BlockSize:], value) + return ciphertext, nil } -// Decrypt a value from a cookie to it's original string -func (c *Cipher) Decrypt(s string) (string, error) { - encrypted, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return "", fmt.Errorf("failed to decrypt cookie value %s", err) +// Decrypt an AES CFB ciphertext +func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) { + if len(ciphertext) < aes.BlockSize { + return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext)) } - if len(encrypted) < aes.BlockSize { - return "", fmt.Errorf("encrypted cookie value should be "+ - "at least %d bytes, but is only %d bytes", - aes.BlockSize, len(encrypted)) - } - - iv := encrypted[:aes.BlockSize] - encrypted = encrypted[aes.BlockSize:] + iv, ciphertext := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:] + plaintext := make([]byte, len(ciphertext)) stream := cipher.NewCFBDecrypter(c.Block, iv) - stream.XORKeyStream(encrypted, encrypted) + stream.XORKeyStream(plaintext, ciphertext) - return string(encrypted), nil + return plaintext, nil } -// EncryptInto encrypts the value and stores it back in the string pointer -func (c *Cipher) EncryptInto(s *string) error { - return into(c.Encrypt, s) +type gcmCipher struct { + cipher.Block } -// DecryptInto decrypts the value and stores it back in the string pointer -func (c *Cipher) DecryptInto(s *string) error { - return into(c.Decrypt, s) -} - -// codecFunc is a function that takes a string and encodes/decodes it -type codecFunc func(string) (string, error) - -func into(f codecFunc, s *string) error { - // Do not encrypt/decrypt nil or empty strings - if s == nil || *s == "" { - return nil - } - - d, err := f(*s) +// NewGCMCipher returns a new AES GCM Cipher +func NewGCMCipher(secret []byte) (Cipher, error) { + c, err := aes.NewCipher(secret) if err != nil { - return err + return nil, err } - *s = d - return nil + return &gcmCipher{Block: c}, err +} + +// Encrypt with AES GCM on raw bytes +func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) { + gcm, err := cipher.NewGCM(c.Block) + if err != nil { + return nil, err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + // Using nonce as Seal's dst argument results in it being the first + // chunk of bytes in the ciphertext. Decrypt retrieves the nonce/IV from this. + ciphertext := gcm.Seal(nonce, nonce, value, nil) + return ciphertext, nil +} + +// Decrypt an AES GCM ciphertext +func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) { + gcm, err := cipher.NewGCM(c.Block) + if err != nil { + return nil, err + } + + nonceSize := gcm.NonceSize() + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + return plaintext, nil } diff --git a/pkg/encryption/cipher_test.go b/pkg/encryption/cipher_test.go index aed529f3..b552e70c 100644 --- a/pkg/encryption/cipher_test.go +++ b/pkg/encryption/cipher_test.go @@ -2,8 +2,6 @@ package encryption import ( "crypto/rand" - "crypto/sha1" - "crypto/sha256" "encoding/base64" "fmt" "io" @@ -12,107 +10,20 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSecretBytesEncoded(t *testing.T) { - for _, secretSize := range []int{16, 24, 32} { - t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { - secret := make([]byte, secretSize) - _, err := io.ReadFull(rand.Reader, secret) - assert.Equal(t, nil, err) - - // We test both padded & raw Base64 to ensure we handle both - // potential user input routes for Base64 - base64Padded := base64.URLEncoding.EncodeToString(secret) - sb := SecretBytes(base64Padded) - assert.Equal(t, secret, sb) - assert.Equal(t, len(sb), secretSize) - - base64Raw := base64.RawURLEncoding.EncodeToString(secret) - sb = SecretBytes(base64Raw) - assert.Equal(t, secret, sb) - assert.Equal(t, len(sb), secretSize) - }) - } -} - -// A string that isn't intended as Base64 and still decodes (but to unintended length) -// will return the original secret as bytes -func TestSecretBytesEncodedWrongSize(t *testing.T) { - for _, secretSize := range []int{15, 20, 28, 33, 44} { - t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { - secret := make([]byte, secretSize) - _, err := io.ReadFull(rand.Reader, secret) - assert.Equal(t, nil, err) - - // We test both padded & raw Base64 to ensure we handle both - // potential user input routes for Base64 - base64Padded := base64.URLEncoding.EncodeToString(secret) - sb := SecretBytes(base64Padded) - assert.NotEqual(t, secret, sb) - assert.NotEqual(t, len(sb), secretSize) - // The given secret is returned as []byte - assert.Equal(t, base64Padded, string(sb)) - - base64Raw := base64.RawURLEncoding.EncodeToString(secret) - sb = SecretBytes(base64Raw) - assert.NotEqual(t, secret, sb) - assert.NotEqual(t, len(sb), secretSize) - // The given secret is returned as []byte - assert.Equal(t, base64Raw, string(sb)) - }) - } -} - -func TestSecretBytesNonBase64(t *testing.T) { - trailer := "equals==========" - assert.Equal(t, trailer, string(SecretBytes(trailer))) - - raw16 := "asdflkjhqwer)(*&" - sb16 := SecretBytes(raw16) - assert.Equal(t, raw16, string(sb16)) - assert.Equal(t, 16, len(sb16)) - - raw24 := "asdflkjhqwer)(*&CJEN#$%^" - sb24 := SecretBytes(raw24) - assert.Equal(t, raw24, string(sb24)) - assert.Equal(t, 24, len(sb24)) - - raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&" - sb32 := SecretBytes(raw32) - assert.Equal(t, raw32, string(sb32)) - assert.Equal(t, 32, len(sb32)) -} - -func TestSignAndValidate(t *testing.T) { - seed := "0123456789abcdef" - key := "cookie-name" - value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded")) - epoch := "123456789" - - sha256sig := cookieSignature(sha256.New, seed, key, value, epoch) - sha1sig := cookieSignature(sha1.New, seed, key, value, epoch) - - assert.True(t, checkSignature(sha256sig, seed, key, value, epoch)) - // This should be switched to False after fully deprecating SHA1 - assert.True(t, checkSignature(sha1sig, seed, key, value, epoch)) - - assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch)) - assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch)) -} - func TestEncodeAndDecodeAccessToken(t *testing.T) { const secret = "0123456789abcdefghijklmnopqrstuv" const token = "my access token" - c, err := NewCipher([]byte(secret)) + c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) assert.Equal(t, nil, err) - encoded, err := c.Encrypt(token) + encoded, err := c.Encrypt([]byte(token)) assert.Equal(t, nil, err) decoded, err := c.Decrypt(encoded) assert.Equal(t, nil, err) - assert.NotEqual(t, token, encoded) - assert.Equal(t, token, decoded) + assert.NotEqual(t, []byte(token), encoded) + assert.Equal(t, []byte(token), decoded) } func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { @@ -121,37 +32,199 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { secret, err := base64.URLEncoding.DecodeString(secretBase64) assert.Equal(t, nil, err) - c, err := NewCipher([]byte(secret)) + c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) assert.Equal(t, nil, err) - encoded, err := c.Encrypt(token) + encoded, err := c.Encrypt([]byte(token)) assert.Equal(t, nil, err) decoded, err := c.Decrypt(encoded) assert.Equal(t, nil, err) - assert.NotEqual(t, token, encoded) - assert.Equal(t, token, decoded) + assert.NotEqual(t, []byte(token), encoded) + assert.Equal(t, []byte(token), decoded) } -func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) { - const secret = "0123456789abcdefghijklmnopqrstuv" - c, err := NewCipher([]byte(secret)) +func TestEncryptAndDecrypt(t *testing.T) { + // Test our 2 cipher types + cipherInits := map[string]func([]byte) (Cipher, error){ + "CFB": NewCFBCipher, + "GCM": NewGCMCipher, + } + for name, initCipher := range cipherInits { + t.Run(name, func(t *testing.T) { + // Test all 3 valid AES sizes + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + // Test Standard & Base64 wrapped + cstd, err := initCipher(secret) + assert.Equal(t, nil, err) + + cb64, err := NewBase64Cipher(initCipher, secret) + assert.Equal(t, nil, err) + + ciphers := map[string]Cipher{ + "Standard": cstd, + "Base64": cb64, + } + + for cName, c := range ciphers { + t.Run(cName, func(t *testing.T) { + // Test various sizes sessions might be + for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { + t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { + runEncryptAndDecrypt(t, c, dataSize) + }) + } + }) + } + }) + } + }) + } +} + +func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) { + data := make([]byte, dataSize) + _, err := io.ReadFull(rand.Reader, data) assert.Equal(t, nil, err) - token := "my access token" - originalToken := token + // Ensure our Encrypt function doesn't encrypt in place + immutableData := make([]byte, len(data)) + copy(immutableData, data) - assert.Equal(t, nil, c.EncryptInto(&token)) - assert.NotEqual(t, originalToken, token) + encrypted, err := c.Encrypt(data) + assert.Equal(t, nil, err) + assert.NotEqual(t, encrypted, data) + // Encrypt didn't operate in-place on []byte + assert.Equal(t, data, immutableData) - assert.Equal(t, nil, c.DecryptInto(&token)) - assert.Equal(t, originalToken, token) + // Ensure our Decrypt function doesn't decrypt in place + immutableEnc := make([]byte, len(encrypted)) + copy(immutableEnc, encrypted) - // Check no errors with empty or nil strings - empty := "" - assert.Equal(t, nil, c.EncryptInto(&empty)) - assert.Equal(t, nil, c.DecryptInto(&empty)) - assert.Equal(t, nil, c.EncryptInto(nil)) - assert.Equal(t, nil, c.DecryptInto(nil)) + decrypted, err := c.Decrypt(encrypted) + assert.Equal(t, nil, err) + // Original data back + assert.Equal(t, data, decrypted) + // Decrypt didn't operate in-place on []byte + assert.Equal(t, encrypted, immutableEnc) + // Encrypt/Decrypt actually did something + assert.NotEqual(t, encrypted, decrypted) +} + +func TestDecryptCFBWrongSecret(t *testing.T) { + secret1 := []byte("0123456789abcdefghijklmnopqrstuv") + secret2 := []byte("9876543210abcdefghijklmnopqrstuv") + + c1, err := NewCFBCipher(secret1) + assert.Equal(t, nil, err) + + c2, err := NewCFBCipher(secret2) + assert.Equal(t, nil, err) + + data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df") + + ciphertext, err := c1.Encrypt(data) + assert.Equal(t, nil, err) + + wrongData, err := c2.Decrypt(ciphertext) + assert.Equal(t, nil, err) + assert.NotEqual(t, data, wrongData) +} + +func TestDecryptGCMWrongSecret(t *testing.T) { + secret1 := []byte("0123456789abcdefghijklmnopqrstuv") + secret2 := []byte("9876543210abcdefghijklmnopqrstuv") + + c1, err := NewGCMCipher(secret1) + assert.Equal(t, nil, err) + + c2, err := NewGCMCipher(secret2) + assert.Equal(t, nil, err) + + data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df") + + ciphertext, err := c1.Encrypt(data) + assert.Equal(t, nil, err) + + // GCM is authenticated - this should lead to message authentication failed + _, err = c2.Decrypt(ciphertext) + assert.Error(t, err) +} + +// Encrypt with GCM, Decrypt with CFB: Results in Garbage data +func TestGCMtoCFBErrors(t *testing.T) { + // Test all 3 valid AES sizes + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + gcm, err := NewGCMCipher(secret) + assert.Equal(t, nil, err) + + cfb, err := NewCFBCipher(secret) + assert.Equal(t, nil, err) + + // Test various sizes sessions might be + for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { + t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { + data := make([]byte, dataSize) + _, err := io.ReadFull(rand.Reader, data) + assert.Equal(t, nil, err) + + encrypted, err := gcm.Encrypt(data) + assert.Equal(t, nil, err) + assert.NotEqual(t, encrypted, data) + + decrypted, err := cfb.Decrypt(encrypted) + assert.Equal(t, nil, err) + // Data is mangled + assert.NotEqual(t, data, decrypted) + assert.NotEqual(t, encrypted, decrypted) + }) + } + }) + } +} + +// Encrypt with CFB, Decrypt with GCM: Results in errors +func TestCFBtoGCMErrors(t *testing.T) { + // Test all 3 valid AES sizes + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + gcm, err := NewGCMCipher(secret) + assert.Equal(t, nil, err) + + cfb, err := NewCFBCipher(secret) + assert.Equal(t, nil, err) + + // Test various sizes sessions might be + for _, dataSize := range []int{10, 100, 1000, 5000, 10000} { + t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) { + data := make([]byte, dataSize) + _, err := io.ReadFull(rand.Reader, data) + assert.Equal(t, nil, err) + + encrypted, err := cfb.Encrypt(data) + assert.Equal(t, nil, err) + assert.NotEqual(t, encrypted, data) + + // GCM is authenticated - this should lead to message authentication failed + _, err = gcm.Decrypt(encrypted) + assert.Error(t, err) + }) + } + }) + } } diff --git a/pkg/encryption/utils.go b/pkg/encryption/utils.go new file mode 100644 index 00000000..26be6b24 --- /dev/null +++ b/pkg/encryption/utils.go @@ -0,0 +1,106 @@ +package encryption + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "fmt" + "hash" + "net/http" + "strconv" + "strings" + "time" +) + +// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary +func SecretBytes(secret string) []byte { + b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "=")) + if err == nil { + // Only return decoded form if a valid AES length + // Don't want unintentional decoding resulting in invalid lengths confusing a user + // that thought they used a 16, 24, 32 length string + for _, i := range []int{16, 24, 32} { + if len(b) == i { + return b + } + } + } + // If decoding didn't work or resulted in non-AES compliant length, + // assume the raw string was the intended secret + return []byte(secret) +} + +// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set. +// additionally, the 'value' is encrypted so it's opaque to the browser + +// Validate ensures a cookie is properly signed +func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value []byte, t time.Time, ok bool) { + // value, timestamp, sig + parts := strings.Split(cookie.Value, "|") + if len(parts) != 3 { + return + } + if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) { + ts, err := strconv.Atoi(parts[1]) + if err != nil { + return + } + // The expiration timestamp set when the cookie was created + // isn't sent back by the browser. Hence, we check whether the + // creation timestamp stored in the cookie falls within the + // window defined by (Now()-expiration, Now()]. + t = time.Unix(int64(ts), 0) + if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { + // it's a valid cookie. now get the contents + rawValue, err := base64.URLEncoding.DecodeString(parts[0]) + if err == nil { + value = rawValue + ok = true + return + } + } + } + return +} + +// SignedValue returns a cookie that is signed and can later be checked with Validate +func SignedValue(seed string, key string, value []byte, now time.Time) string { + encodedValue := base64.URLEncoding.EncodeToString(value) + timeStr := fmt.Sprintf("%d", now.Unix()) + sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr) + cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) + return cookieVal +} + +func cookieSignature(signer func() hash.Hash, args ...string) string { + h := hmac.New(signer, []byte(args[0])) + for _, arg := range args[1:] { + h.Write([]byte(arg)) + } + var b []byte + b = h.Sum(b) + return base64.URLEncoding.EncodeToString(b) +} + +func checkSignature(signature string, args ...string) bool { + checkSig := cookieSignature(sha256.New, args...) + if checkHmac(signature, checkSig) { + return true + } + + // TODO: After appropriate rollout window, remove support for SHA1 + legacySig := cookieSignature(sha1.New, args...) + return checkHmac(signature, legacySig) +} + +func checkHmac(input, expected string) bool { + inputMAC, err1 := base64.URLEncoding.DecodeString(input) + if err1 == nil { + expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) + if err2 == nil { + return hmac.Equal(inputMAC, expectedMAC) + } + } + return false +} diff --git a/pkg/encryption/utils_test.go b/pkg/encryption/utils_test.go new file mode 100644 index 00000000..15bc83fe --- /dev/null +++ b/pkg/encryption/utils_test.go @@ -0,0 +1,100 @@ +package encryption + +import ( + "crypto/rand" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSecretBytesEncoded(t *testing.T) { + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + // We test both padded & raw Base64 to ensure we handle both + // potential user input routes for Base64 + base64Padded := base64.URLEncoding.EncodeToString(secret) + sb := SecretBytes(base64Padded) + assert.Equal(t, secret, sb) + assert.Equal(t, len(sb), secretSize) + + base64Raw := base64.RawURLEncoding.EncodeToString(secret) + sb = SecretBytes(base64Raw) + assert.Equal(t, secret, sb) + assert.Equal(t, len(sb), secretSize) + }) + } +} + +// A string that isn't intended as Base64 and still decodes (but to unintended length) +// will return the original secret as bytes +func TestSecretBytesEncodedWrongSize(t *testing.T) { + for _, secretSize := range []int{15, 20, 28, 33, 44} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + // We test both padded & raw Base64 to ensure we handle both + // potential user input routes for Base64 + base64Padded := base64.URLEncoding.EncodeToString(secret) + sb := SecretBytes(base64Padded) + assert.NotEqual(t, secret, sb) + assert.NotEqual(t, len(sb), secretSize) + // The given secret is returned as []byte + assert.Equal(t, base64Padded, string(sb)) + + base64Raw := base64.RawURLEncoding.EncodeToString(secret) + sb = SecretBytes(base64Raw) + assert.NotEqual(t, secret, sb) + assert.NotEqual(t, len(sb), secretSize) + // The given secret is returned as []byte + assert.Equal(t, base64Raw, string(sb)) + }) + } +} + +func TestSecretBytesNonBase64(t *testing.T) { + trailer := "equals==========" + assert.Equal(t, trailer, string(SecretBytes(trailer))) + + raw16 := "asdflkjhqwer)(*&" + sb16 := SecretBytes(raw16) + assert.Equal(t, raw16, string(sb16)) + assert.Equal(t, 16, len(sb16)) + + raw24 := "asdflkjhqwer)(*&CJEN#$%^" + sb24 := SecretBytes(raw24) + assert.Equal(t, raw24, string(sb24)) + assert.Equal(t, 24, len(sb24)) + + raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&" + sb32 := SecretBytes(raw32) + assert.Equal(t, raw32, string(sb32)) + assert.Equal(t, 32, len(sb32)) +} + +func TestSignAndValidate(t *testing.T) { + seed := "0123456789abcdef" + key := "cookie-name" + value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded")) + epoch := "123456789" + + sha256sig := cookieSignature(sha256.New, seed, key, value, epoch) + sha1sig := cookieSignature(sha1.New, seed, key, value, epoch) + + assert.True(t, checkSignature(sha256sig, seed, key, value, epoch)) + // This should be switched to False after fully deprecating SHA1 + assert.True(t, checkSignature(sha1sig, seed, key, value, epoch)) + + assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch)) + assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch)) +} diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 1b88e027..62f6b348 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -28,7 +28,7 @@ var _ sessions.SessionStore = &SessionStore{} // interface that stores sessions in client side cookies type SessionStore struct { CookieOptions *options.CookieOptions - CookieCipher *encryption.Cipher + CookieCipher encryption.Cipher } // Save takes a sessions.SessionState and stores the information from it @@ -59,7 +59,7 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { return nil, errors.New("cookie signature not valid") } - session, err := sessionFromCookie(val, s.CookieCipher) + session, err := sessionFromCookie(string(val), s.CookieCipher) if err != nil { return nil, err } @@ -84,12 +84,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { } // cookieForSession serializes a session state for storage in a cookie -func cookieForSession(s *sessions.SessionState, c *encryption.Cipher) (string, error) { +func cookieForSession(s *sessions.SessionState, c encryption.Cipher) (string, error) { return s.EncodeSessionState(c) } // sessionFromCookie deserializes a session from a cookie value -func sessionFromCookie(v string, c *encryption.Cipher) (s *sessions.SessionState, err error) { +func sessionFromCookie(v string, c encryption.Cipher) (s *sessions.SessionState, err error) { return sessions.DecodeSessionState(v, c) } @@ -104,7 +104,7 @@ func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Reques // authentication details func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie { if value != "" { - value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, value, now) + value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, []byte(value), now) } c := s.makeCookie(req, s.CookieOptions.Name, value, s.CookieOptions.Expire, now) if len(c.Value) > 4096-len(s.CookieOptions.Name) { diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index 472ac3c9..2d0fd9b0 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -32,7 +32,7 @@ type TicketData struct { // SessionStore is an implementation of the sessions.SessionStore // interface that stores sessions in redis type SessionStore struct { - CookieCipher *encryption.Cipher + CookieCipher encryption.Cipher CookieOptions *options.CookieOptions Client Client } @@ -175,7 +175,7 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro return nil, fmt.Errorf("cookie signature not valid") } ctx := req.Context() - session, err := store.loadSessionFromString(ctx, val) + session, err := store.loadSessionFromString(ctx, string(val)) if err != nil { return nil, fmt.Errorf("error loading session: %s", err) } @@ -237,7 +237,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro // We only return an error if we had an issue with redis // If there's an issue decoding the ticket, ignore it - ticket, _ := decodeTicket(store.CookieOptions.Name, val) + ticket, _ := decodeTicket(store.CookieOptions.Name, string(val)) if ticket != nil { ctx := req.Context() err := store.Client.Del(ctx, ticket.asHandle(store.CookieOptions.Name)) @@ -251,7 +251,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro // makeCookie makes a cookie, signing the value if present func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { if value != "" { - value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, value, now) + value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, []byte(value), now) } return cookies.MakeCookieFromOptions( req, @@ -302,7 +302,7 @@ func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, e } // Valid cookie, decode the ticket - ticket, err := decodeTicket(store.CookieOptions.Name, val) + ticket, err := decodeTicket(store.CookieOptions.Name, string(val)) if err != nil { // If we can't decode the ticket we have to create a new one return newTicket() diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 60a86cef..1a60aa0d 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -170,7 +170,7 @@ var _ = Describe("NewSessionStore", func() { BeforeEach(func() { By("Using a valid cookie with a different providers session encoding") broken := "BrokenSessionFromADifferentSessionImplementation" - value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, broken, time.Now()) + value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, []byte(broken), time.Now()) cookie := cookiesapi.MakeCookieFromOptions(request, cookieOpts.Name, value, cookieOpts, cookieOpts.Expire, time.Now()) request.AddCookie(cookie) @@ -367,7 +367,7 @@ var _ = Describe("NewSessionStore", func() { _, err := rand.Read(secret) Expect(err).ToNot(HaveOccurred()) cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret) - cipher, err := encryption.NewCipher(encryption.SecretBytes(cookieOpts.Secret)) + cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) Expect(err).ToNot(HaveOccurred()) Expect(cipher).ToNot(BeNil()) opts.Cipher = cipher diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 2e028677..b22882d0 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -38,7 +38,7 @@ func Validate(o *options.Options) error { msgs := make([]string, 0) - var cipher *encryption.Cipher + var cipher encryption.Cipher if o.Cookie.Secret == "" { msgs = append(msgs, "missing setting: cookie-secret") } else { @@ -62,7 +62,7 @@ func Validate(o *options.Options) error { len(encryption.SecretBytes(o.Cookie.Secret)), suffix)) } else { var err error - cipher, err = encryption.NewCipher(encryption.SecretBytes(o.Cookie.Secret)) + cipher, err = encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(o.Cookie.Secret)) if err != nil { msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err)) }