Merge pull request #539 from grnhse/encryption-efficiency-improvements
Encryption efficiency improvements
This commit is contained in:
		
						commit
						a197a17bc3
					
				| 
						 | 
					@ -55,6 +55,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Changes since v5.1.1
 | 
					## Changes since v5.1.1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves)
 | 
				
			||||||
- [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed)
 | 
					- [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed)
 | 
				
			||||||
- [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed)
 | 
					- [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed)
 | 
				
			||||||
- [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer)
 | 
					- [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,7 +5,7 @@ import "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
 | 
				
			||||||
// SessionOptions contains configuration options for the SessionStore providers.
 | 
					// SessionOptions contains configuration options for the SessionStore providers.
 | 
				
			||||||
type SessionOptions struct {
 | 
					type SessionOptions struct {
 | 
				
			||||||
	Type   string            `flag:"session-store-type" cfg:"session_store_type"`
 | 
						Type   string            `flag:"session-store-type" cfg:"session_store_type"`
 | 
				
			||||||
	Cipher *encryption.Cipher `cfg:",internal"`
 | 
						Cipher encryption.Cipher `cfg:",internal"`
 | 
				
			||||||
	Redis  RedisStoreOptions `cfg:",squash"`
 | 
						Redis  RedisStoreOptions `cfg:",squash"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,7 +60,7 @@ func (s *SessionState) String() string {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// EncodeSessionState returns string representation of the current session
 | 
					// EncodeSessionState returns string representation of the current session
 | 
				
			||||||
func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) {
 | 
					func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) {
 | 
				
			||||||
	var ss SessionState
 | 
						var ss SessionState
 | 
				
			||||||
	if c == nil {
 | 
						if c == nil {
 | 
				
			||||||
		// Store only Email and User when cipher is unavailable
 | 
							// Store only Email and User when cipher is unavailable
 | 
				
			||||||
| 
						 | 
					@ -77,7 +77,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
 | 
				
			||||||
			&ss.IDToken,
 | 
								&ss.IDToken,
 | 
				
			||||||
			&ss.RefreshToken,
 | 
								&ss.RefreshToken,
 | 
				
			||||||
		} {
 | 
							} {
 | 
				
			||||||
			err := c.EncryptInto(s)
 | 
								err := into(s, c.Encrypt)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return "", err
 | 
									return "", err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -89,7 +89,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// DecodeSessionState decodes the session cookie string into a SessionState
 | 
					// DecodeSessionState decodes the session cookie string into a SessionState
 | 
				
			||||||
func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
 | 
					func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) {
 | 
				
			||||||
	var ss SessionState
 | 
						var ss SessionState
 | 
				
			||||||
	err := json.Unmarshal([]byte(v), &ss)
 | 
						err := json.Unmarshal([]byte(v), &ss)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -104,25 +104,19 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
 | 
				
			||||||
			PreferredUsername: ss.PreferredUsername,
 | 
								PreferredUsername: ss.PreferredUsername,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		// Backward compatibility with using unencrypted Email
 | 
							// Backward compatibility with using unencrypted Email or User
 | 
				
			||||||
		if ss.Email != "" {
 | 
							// Decryption errors will leave original string
 | 
				
			||||||
			decryptedEmail, errEmail := c.Decrypt(ss.Email)
 | 
							err = into(&ss.Email, c.Decrypt)
 | 
				
			||||||
			if errEmail == nil {
 | 
							if err == nil {
 | 
				
			||||||
				if !utf8.ValidString(decryptedEmail) {
 | 
								if !utf8.ValidString(ss.Email) {
 | 
				
			||||||
				return nil, errors.New("invalid value for decrypted email")
 | 
									return nil, errors.New("invalid value for decrypted email")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
				ss.Email = decryptedEmail
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		}
 | 
							err = into(&ss.User, c.Decrypt)
 | 
				
			||||||
		// Backward compatibility with using unencrypted User
 | 
							if err == nil {
 | 
				
			||||||
		if ss.User != "" {
 | 
								if !utf8.ValidString(ss.User) {
 | 
				
			||||||
			decryptedUser, errUser := c.Decrypt(ss.User)
 | 
					 | 
				
			||||||
			if errUser == nil {
 | 
					 | 
				
			||||||
				if !utf8.ValidString(decryptedUser) {
 | 
					 | 
				
			||||||
				return nil, errors.New("invalid value for decrypted user")
 | 
									return nil, errors.New("invalid value for decrypted user")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
				ss.User = decryptedUser
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for _, s := range []*string{
 | 
							for _, s := range []*string{
 | 
				
			||||||
| 
						 | 
					@ -131,7 +125,7 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
 | 
				
			||||||
			&ss.IDToken,
 | 
								&ss.IDToken,
 | 
				
			||||||
			&ss.RefreshToken,
 | 
								&ss.RefreshToken,
 | 
				
			||||||
		} {
 | 
							} {
 | 
				
			||||||
			err := c.DecryptInto(s)
 | 
								err := into(s, c.Decrypt)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return nil, err
 | 
									return nil, err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -139,3 +133,20 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &ss, nil
 | 
						return &ss, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// codecFunc is a function that takes a []byte and encodes/decodes it
 | 
				
			||||||
 | 
					type codecFunc func([]byte) ([]byte, error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func into(s *string, f codecFunc) error {
 | 
				
			||||||
 | 
						// Do not encrypt/decrypt nil or empty strings
 | 
				
			||||||
 | 
						if s == nil || *s == "" {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						d, err := f([]byte(*s))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						*s = string(d)
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,11 +1,13 @@
 | 
				
			||||||
package sessions_test
 | 
					package sessions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						mathrand "math/rand"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
 | 
					 | 
				
			||||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
 | 
						"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
 | 
				
			||||||
	"github.com/stretchr/testify/assert"
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -17,12 +19,16 @@ func timePtr(t time.Time) *time.Time {
 | 
				
			||||||
	return &t
 | 
						return &t
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newTestCipher(secret []byte) (encryption.Cipher, error) {
 | 
				
			||||||
 | 
						return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSessionStateSerialization(t *testing.T) {
 | 
					func TestSessionStateSerialization(t *testing.T) {
 | 
				
			||||||
	c, err := encryption.NewCipher([]byte(secret))
 | 
						c, err := newTestCipher([]byte(secret))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	c2, err := encryption.NewCipher([]byte(altSecret))
 | 
						c2, err := newTestCipher([]byte(altSecret))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	s := &sessions.SessionState{
 | 
						s := &SessionState{
 | 
				
			||||||
		Email:             "user@domain.com",
 | 
							Email:             "user@domain.com",
 | 
				
			||||||
		PreferredUsername: "user",
 | 
							PreferredUsername: "user",
 | 
				
			||||||
		AccessToken:       "token1234",
 | 
							AccessToken:       "token1234",
 | 
				
			||||||
| 
						 | 
					@ -34,7 +40,7 @@ func TestSessionStateSerialization(t *testing.T) {
 | 
				
			||||||
	encoded, err := s.EncodeSessionState(c)
 | 
						encoded, err := s.EncodeSessionState(c)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ss, err := sessions.DecodeSessionState(encoded, c)
 | 
						ss, err := DecodeSessionState(encoded, c)
 | 
				
			||||||
	t.Logf("%#v", ss)
 | 
						t.Logf("%#v", ss)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", ss.User)
 | 
						assert.Equal(t, "", ss.User)
 | 
				
			||||||
| 
						 | 
					@ -47,17 +53,17 @@ func TestSessionStateSerialization(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, s.RefreshToken, ss.RefreshToken)
 | 
						assert.Equal(t, s.RefreshToken, ss.RefreshToken)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 
						// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 
				
			||||||
	ss, err = sessions.DecodeSessionState(encoded, c2)
 | 
						ss, err = DecodeSessionState(encoded, c2)
 | 
				
			||||||
	t.Logf("%#v", ss)
 | 
						t.Logf("%#v", ss)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSessionStateSerializationWithUser(t *testing.T) {
 | 
					func TestSessionStateSerializationWithUser(t *testing.T) {
 | 
				
			||||||
	c, err := encryption.NewCipher([]byte(secret))
 | 
						c, err := newTestCipher([]byte(secret))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	c2, err := encryption.NewCipher([]byte(altSecret))
 | 
						c2, err := newTestCipher([]byte(altSecret))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	s := &sessions.SessionState{
 | 
						s := &SessionState{
 | 
				
			||||||
		User:              "just-user",
 | 
							User:              "just-user",
 | 
				
			||||||
		PreferredUsername: "ju",
 | 
							PreferredUsername: "ju",
 | 
				
			||||||
		Email:             "user@domain.com",
 | 
							Email:             "user@domain.com",
 | 
				
			||||||
| 
						 | 
					@ -69,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
 | 
				
			||||||
	encoded, err := s.EncodeSessionState(c)
 | 
						encoded, err := s.EncodeSessionState(c)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ss, err := sessions.DecodeSessionState(encoded, c)
 | 
						ss, err := DecodeSessionState(encoded, c)
 | 
				
			||||||
	t.Logf("%#v", ss)
 | 
						t.Logf("%#v", ss)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, s.User, ss.User)
 | 
						assert.Equal(t, s.User, ss.User)
 | 
				
			||||||
| 
						 | 
					@ -81,13 +87,13 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, s.RefreshToken, ss.RefreshToken)
 | 
						assert.Equal(t, s.RefreshToken, ss.RefreshToken)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 
						// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 
				
			||||||
	ss, err = sessions.DecodeSessionState(encoded, c2)
 | 
						ss, err = DecodeSessionState(encoded, c2)
 | 
				
			||||||
	t.Logf("%#v", ss)
 | 
						t.Logf("%#v", ss)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
 | 
					func TestSessionStateSerializationNoCipher(t *testing.T) {
 | 
				
			||||||
	s := &sessions.SessionState{
 | 
						s := &SessionState{
 | 
				
			||||||
		Email:             "user@domain.com",
 | 
							Email:             "user@domain.com",
 | 
				
			||||||
		PreferredUsername: "user",
 | 
							PreferredUsername: "user",
 | 
				
			||||||
		AccessToken:       "token1234",
 | 
							AccessToken:       "token1234",
 | 
				
			||||||
| 
						 | 
					@ -99,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// only email should have been serialized
 | 
						// only email should have been serialized
 | 
				
			||||||
	ss, err := sessions.DecodeSessionState(encoded, nil)
 | 
						ss, err := DecodeSessionState(encoded, nil)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", ss.User)
 | 
						assert.Equal(t, "", ss.User)
 | 
				
			||||||
	assert.Equal(t, s.Email, ss.Email)
 | 
						assert.Equal(t, s.Email, ss.Email)
 | 
				
			||||||
| 
						 | 
					@ -109,7 +115,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
 | 
					func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
 | 
				
			||||||
	s := &sessions.SessionState{
 | 
						s := &SessionState{
 | 
				
			||||||
		User:              "just-user",
 | 
							User:              "just-user",
 | 
				
			||||||
		Email:             "user@domain.com",
 | 
							Email:             "user@domain.com",
 | 
				
			||||||
		PreferredUsername: "user",
 | 
							PreferredUsername: "user",
 | 
				
			||||||
| 
						 | 
					@ -122,7 +128,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// only email should have been serialized
 | 
						// only email should have been serialized
 | 
				
			||||||
	ss, err := sessions.DecodeSessionState(encoded, nil)
 | 
						ss, err := DecodeSessionState(encoded, nil)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, s.User, ss.User)
 | 
						assert.Equal(t, s.User, ss.User)
 | 
				
			||||||
	assert.Equal(t, s.Email, ss.Email)
 | 
						assert.Equal(t, s.Email, ss.Email)
 | 
				
			||||||
| 
						 | 
					@ -132,20 +138,20 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestExpired(t *testing.T) {
 | 
					func TestExpired(t *testing.T) {
 | 
				
			||||||
	s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
 | 
						s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
 | 
				
			||||||
	assert.Equal(t, true, s.IsExpired())
 | 
						assert.Equal(t, true, s.IsExpired())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
 | 
						s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
 | 
				
			||||||
	assert.Equal(t, false, s.IsExpired())
 | 
						assert.Equal(t, false, s.IsExpired())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s = &sessions.SessionState{}
 | 
						s = &SessionState{}
 | 
				
			||||||
	assert.Equal(t, false, s.IsExpired())
 | 
						assert.Equal(t, false, s.IsExpired())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type testCase struct {
 | 
					type testCase struct {
 | 
				
			||||||
	sessions.SessionState
 | 
						SessionState
 | 
				
			||||||
	Encoded string
 | 
						Encoded string
 | 
				
			||||||
	Cipher  *encryption.Cipher
 | 
						Cipher  encryption.Cipher
 | 
				
			||||||
	Error   bool
 | 
						Error   bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -159,14 +165,14 @@ func TestEncodeSessionState(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	testCases := []testCase{
 | 
						testCases := []testCase{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email: "user@domain.com",
 | 
									Email: "user@domain.com",
 | 
				
			||||||
				User:  "just-user",
 | 
									User:  "just-user",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
 | 
								Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email:        "user@domain.com",
 | 
									Email:        "user@domain.com",
 | 
				
			||||||
				User:         "just-user",
 | 
									User:         "just-user",
 | 
				
			||||||
				AccessToken:  "token1234",
 | 
									AccessToken:  "token1234",
 | 
				
			||||||
| 
						 | 
					@ -181,7 +187,7 @@ func TestEncodeSessionState(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i, tc := range testCases {
 | 
						for i, tc := range testCases {
 | 
				
			||||||
		encoded, err := tc.EncodeSessionState(tc.Cipher)
 | 
							encoded, err := tc.EncodeSessionState(tc.Cipher)
 | 
				
			||||||
		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
 | 
							t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
 | 
				
			||||||
		if tc.Error {
 | 
							if tc.Error {
 | 
				
			||||||
			assert.Error(t, err)
 | 
								assert.Error(t, err)
 | 
				
			||||||
			assert.Empty(t, encoded)
 | 
								assert.Empty(t, encoded)
 | 
				
			||||||
| 
						 | 
					@ -201,39 +207,39 @@ func TestDecodeSessionState(t *testing.T) {
 | 
				
			||||||
	eJSON, _ := e.MarshalJSON()
 | 
						eJSON, _ := e.MarshalJSON()
 | 
				
			||||||
	eString := string(eJSON)
 | 
						eString := string(eJSON)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c, err := encryption.NewCipher([]byte(secret))
 | 
						c, err := newTestCipher([]byte(secret))
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	testCases := []testCase{
 | 
						testCases := []testCase{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email: "user@domain.com",
 | 
									Email: "user@domain.com",
 | 
				
			||||||
				User:  "just-user",
 | 
									User:  "just-user",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
 | 
								Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email: "user@domain.com",
 | 
									Email: "user@domain.com",
 | 
				
			||||||
				User:  "",
 | 
									User:  "",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			Encoded: `{"Email":"user@domain.com"}`,
 | 
								Encoded: `{"Email":"user@domain.com"}`,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				User: "just-user",
 | 
									User: "just-user",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			Encoded: `{"User":"just-user"}`,
 | 
								Encoded: `{"User":"just-user"}`,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email: "user@domain.com",
 | 
									Email: "user@domain.com",
 | 
				
			||||||
				User:  "just-user",
 | 
									User:  "just-user",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
 | 
								Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email:        "user@domain.com",
 | 
									Email:        "user@domain.com",
 | 
				
			||||||
				User:         "just-user",
 | 
									User:         "just-user",
 | 
				
			||||||
				AccessToken:  "token1234",
 | 
									AccessToken:  "token1234",
 | 
				
			||||||
| 
						 | 
					@ -246,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
 | 
				
			||||||
			Cipher:  c,
 | 
								Cipher:  c,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email: "user@domain.com",
 | 
									Email: "user@domain.com",
 | 
				
			||||||
				User:  "just-user",
 | 
									User:  "just-user",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
| 
						 | 
					@ -264,7 +270,7 @@ func TestDecodeSessionState(t *testing.T) {
 | 
				
			||||||
			Error:   true,
 | 
								Error:   true,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			SessionState: sessions.SessionState{
 | 
								SessionState: SessionState{
 | 
				
			||||||
				Email: "user@domain.com",
 | 
									Email: "user@domain.com",
 | 
				
			||||||
				User:  "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
 | 
									User:  "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
| 
						 | 
					@ -274,8 +280,8 @@ func TestDecodeSessionState(t *testing.T) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i, tc := range testCases {
 | 
						for i, tc := range testCases {
 | 
				
			||||||
		ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
 | 
							ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
 | 
				
			||||||
		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
 | 
							t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
 | 
				
			||||||
		if tc.Error {
 | 
							if tc.Error {
 | 
				
			||||||
			assert.Error(t, err)
 | 
								assert.Error(t, err)
 | 
				
			||||||
			assert.Nil(t, ss)
 | 
								assert.Nil(t, ss)
 | 
				
			||||||
| 
						 | 
					@ -297,7 +303,7 @@ func TestDecodeSessionState(t *testing.T) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSessionStateAge(t *testing.T) {
 | 
					func TestSessionStateAge(t *testing.T) {
 | 
				
			||||||
	ss := &sessions.SessionState{}
 | 
						ss := &SessionState{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Created at unset so should be 0
 | 
						// Created at unset so should be 0
 | 
				
			||||||
	assert.Equal(t, time.Duration(0), ss.Age())
 | 
						assert.Equal(t, time.Duration(0), ss.Age())
 | 
				
			||||||
| 
						 | 
					@ -306,3 +312,44 @@ func TestSessionStateAge(t *testing.T) {
 | 
				
			||||||
	ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour))
 | 
						ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour))
 | 
				
			||||||
	assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
 | 
						assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIntoEncryptAndIntoDecrypt(t *testing.T) {
 | 
				
			||||||
 | 
						const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Test all 3 valid AES sizes
 | 
				
			||||||
 | 
						for _, secretSize := range []int{16, 24, 32} {
 | 
				
			||||||
 | 
							t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
				
			||||||
 | 
								secret := make([]byte, secretSize)
 | 
				
			||||||
 | 
								_, err := io.ReadFull(rand.Reader, secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								c, err := newTestCipher(secret)
 | 
				
			||||||
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Check no errors with empty or nil strings
 | 
				
			||||||
 | 
								empty := ""
 | 
				
			||||||
 | 
								assert.Equal(t, nil, into(&empty, c.Encrypt))
 | 
				
			||||||
 | 
								assert.Equal(t, nil, into(&empty, c.Decrypt))
 | 
				
			||||||
 | 
								assert.Equal(t, nil, into(nil, c.Encrypt))
 | 
				
			||||||
 | 
								assert.Equal(t, nil, into(nil, c.Decrypt))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Test various sizes tokens might be
 | 
				
			||||||
 | 
								for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
 | 
				
			||||||
 | 
									t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
 | 
				
			||||||
 | 
										b := make([]byte, dataSize)
 | 
				
			||||||
 | 
										for i := range b {
 | 
				
			||||||
 | 
											b[i] = charset[mathrand.Intn(len(charset))]
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										data := string(b)
 | 
				
			||||||
 | 
										originalData := data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										assert.Equal(t, nil, into(&data, c.Encrypt))
 | 
				
			||||||
 | 
										assert.NotEqual(t, originalData, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										assert.Equal(t, nil, into(&data, c.Decrypt))
 | 
				
			||||||
 | 
										assert.Equal(t, originalData, data)
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,183 +3,134 @@ package encryption
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"crypto/aes"
 | 
						"crypto/aes"
 | 
				
			||||||
	"crypto/cipher"
 | 
						"crypto/cipher"
 | 
				
			||||||
	"crypto/hmac"
 | 
					 | 
				
			||||||
	"crypto/rand"
 | 
						"crypto/rand"
 | 
				
			||||||
	"crypto/sha1"
 | 
					 | 
				
			||||||
	"crypto/sha256"
 | 
					 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"hash"
 | 
					 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
 | 
					// Cipher provides methods to encrypt and decrypt
 | 
				
			||||||
func SecretBytes(secret string) []byte {
 | 
					type Cipher interface {
 | 
				
			||||||
	b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "="))
 | 
						Encrypt(value []byte) ([]byte, error)
 | 
				
			||||||
	if err == nil {
 | 
						Decrypt(ciphertext []byte) ([]byte, error)
 | 
				
			||||||
		// Only return decoded form if a valid AES length
 | 
					 | 
				
			||||||
		// Don't want unintentional decoding resulting in invalid lengths confusing a user
 | 
					 | 
				
			||||||
		// that thought they used a 16, 24, 32 length string
 | 
					 | 
				
			||||||
		for _, i := range []int{16, 24, 32} {
 | 
					 | 
				
			||||||
			if len(b) == i {
 | 
					 | 
				
			||||||
				return b
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// If decoding didn't work or resulted in non-AES compliant length,
 | 
					 | 
				
			||||||
	// assume the raw string was the intended secret
 | 
					 | 
				
			||||||
	return []byte(secret)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
 | 
					type base64Cipher struct {
 | 
				
			||||||
// additionally, the 'value' is encrypted so it's opaque to the browser
 | 
						Cipher Cipher
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Validate ensures a cookie is properly signed
 | 
					// NewBase64Cipher returns a new AES Cipher for encrypting cookie values
 | 
				
			||||||
func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) {
 | 
					// and wrapping them in Base64 -- Supports Legacy encryption scheme
 | 
				
			||||||
	// value, timestamp, sig
 | 
					func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Cipher, error) {
 | 
				
			||||||
	parts := strings.Split(cookie.Value, "|")
 | 
						c, err := initCipher(secret)
 | 
				
			||||||
	if len(parts) != 3 {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) {
 | 
					 | 
				
			||||||
		ts, err := strconv.Atoi(parts[1])
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
			return
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
		// The expiration timestamp set when the cookie was created
 | 
						return &base64Cipher{Cipher: c}, nil
 | 
				
			||||||
		// isn't sent back by the browser. Hence, we check whether the
 | 
					 | 
				
			||||||
		// creation timestamp stored in the cookie falls within the
 | 
					 | 
				
			||||||
		// window defined by (Now()-expiration, Now()].
 | 
					 | 
				
			||||||
		t = time.Unix(int64(ts), 0)
 | 
					 | 
				
			||||||
		if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
 | 
					 | 
				
			||||||
			// it's a valid cookie. now get the contents
 | 
					 | 
				
			||||||
			rawValue, err := base64.URLEncoding.DecodeString(parts[0])
 | 
					 | 
				
			||||||
			if err == nil {
 | 
					 | 
				
			||||||
				value = string(rawValue)
 | 
					 | 
				
			||||||
				ok = true
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SignedValue returns a cookie that is signed and can later be checked with Validate
 | 
					// Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
 | 
				
			||||||
func SignedValue(seed string, key string, value string, now time.Time) string {
 | 
					func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) {
 | 
				
			||||||
	encodedValue := base64.URLEncoding.EncodeToString([]byte(value))
 | 
						encrypted, err := c.Cipher.Encrypt(value)
 | 
				
			||||||
	timeStr := fmt.Sprintf("%d", now.Unix())
 | 
						if err != nil {
 | 
				
			||||||
	sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr)
 | 
							return nil, err
 | 
				
			||||||
	cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
 | 
					 | 
				
			||||||
	return cookieVal
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func cookieSignature(signer func() hash.Hash, args ...string) string {
 | 
					 | 
				
			||||||
	h := hmac.New(signer, []byte(args[0]))
 | 
					 | 
				
			||||||
	for _, arg := range args[1:] {
 | 
					 | 
				
			||||||
		h.Write([]byte(arg))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var b []byte
 | 
					 | 
				
			||||||
	b = h.Sum(b)
 | 
					 | 
				
			||||||
	return base64.URLEncoding.EncodeToString(b)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func checkSignature(signature string, args ...string) bool {
 | 
					 | 
				
			||||||
	checkSig := cookieSignature(sha256.New, args...)
 | 
					 | 
				
			||||||
	if checkHmac(signature, checkSig) {
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO: After appropriate rollout window, remove support for SHA1
 | 
						return []byte(base64.StdEncoding.EncodeToString(encrypted)), nil
 | 
				
			||||||
	legacySig := cookieSignature(sha1.New, args...)
 | 
					 | 
				
			||||||
	return checkHmac(signature, legacySig)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func checkHmac(input, expected string) bool {
 | 
					// Decrypt Base64 decodes a value & decrypts it with the embedded Cipher
 | 
				
			||||||
	inputMAC, err1 := base64.URLEncoding.DecodeString(input)
 | 
					func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
 | 
				
			||||||
	if err1 == nil {
 | 
						encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext))
 | 
				
			||||||
		expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
 | 
						if err != nil {
 | 
				
			||||||
		if err2 == nil {
 | 
							return nil, fmt.Errorf("failed to base64 decode value %s", err)
 | 
				
			||||||
			return hmac.Equal(inputMAC, expectedMAC)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	}
 | 
					
 | 
				
			||||||
	return false
 | 
						return c.Cipher.Decrypt(encrypted)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Cipher provides methods to encrypt and decrypt cookie values
 | 
					type cfbCipher struct {
 | 
				
			||||||
type Cipher struct {
 | 
					 | 
				
			||||||
	cipher.Block
 | 
						cipher.Block
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewCipher returns a new aes Cipher for encrypting cookie values
 | 
					// NewCFBCipher returns a new AES CFB Cipher
 | 
				
			||||||
func NewCipher(secret []byte) (*Cipher, error) {
 | 
					func NewCFBCipher(secret []byte) (Cipher, error) {
 | 
				
			||||||
	c, err := aes.NewCipher(secret)
 | 
						c, err := aes.NewCipher(secret)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &Cipher{Block: c}, err
 | 
						return &cfbCipher{Block: c}, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Encrypt a value for use in a cookie
 | 
					// Encrypt with AES CFB
 | 
				
			||||||
func (c *Cipher) Encrypt(value string) (string, error) {
 | 
					func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) {
 | 
				
			||||||
	ciphertext := make([]byte, aes.BlockSize+len(value))
 | 
						ciphertext := make([]byte, aes.BlockSize+len(value))
 | 
				
			||||||
	iv := ciphertext[:aes.BlockSize]
 | 
						iv := ciphertext[:aes.BlockSize]
 | 
				
			||||||
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
 | 
						if _, err := io.ReadFull(rand.Reader, iv); err != nil {
 | 
				
			||||||
		return "", fmt.Errorf("failed to create initialization vector %s", err)
 | 
							return nil, fmt.Errorf("failed to create initialization vector %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stream := cipher.NewCFBEncrypter(c.Block, iv)
 | 
						stream := cipher.NewCFBEncrypter(c.Block, iv)
 | 
				
			||||||
	stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value))
 | 
						stream.XORKeyStream(ciphertext[aes.BlockSize:], value)
 | 
				
			||||||
	return base64.StdEncoding.EncodeToString(ciphertext), nil
 | 
						return ciphertext, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Decrypt a value from a cookie to it's original string
 | 
					// Decrypt an AES CFB ciphertext
 | 
				
			||||||
func (c *Cipher) Decrypt(s string) (string, error) {
 | 
					func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) {
 | 
				
			||||||
	encrypted, err := base64.StdEncoding.DecodeString(s)
 | 
						if len(ciphertext) < aes.BlockSize {
 | 
				
			||||||
	if err != nil {
 | 
							return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext))
 | 
				
			||||||
		return "", fmt.Errorf("failed to decrypt cookie value %s", err)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(encrypted) < aes.BlockSize {
 | 
						iv, ciphertext := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:]
 | 
				
			||||||
		return "", fmt.Errorf("encrypted cookie value should be "+
 | 
						plaintext := make([]byte, len(ciphertext))
 | 
				
			||||||
			"at least %d bytes, but is only %d bytes",
 | 
					 | 
				
			||||||
			aes.BlockSize, len(encrypted))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	iv := encrypted[:aes.BlockSize]
 | 
					 | 
				
			||||||
	encrypted = encrypted[aes.BlockSize:]
 | 
					 | 
				
			||||||
	stream := cipher.NewCFBDecrypter(c.Block, iv)
 | 
						stream := cipher.NewCFBDecrypter(c.Block, iv)
 | 
				
			||||||
	stream.XORKeyStream(encrypted, encrypted)
 | 
						stream.XORKeyStream(plaintext, ciphertext)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return string(encrypted), nil
 | 
						return plaintext, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// EncryptInto encrypts the value and stores it back in the string pointer
 | 
					type gcmCipher struct {
 | 
				
			||||||
func (c *Cipher) EncryptInto(s *string) error {
 | 
						cipher.Block
 | 
				
			||||||
	return into(c.Encrypt, s)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// DecryptInto decrypts the value and stores it back in the string pointer
 | 
					// NewGCMCipher returns a new AES GCM Cipher
 | 
				
			||||||
func (c *Cipher) DecryptInto(s *string) error {
 | 
					func NewGCMCipher(secret []byte) (Cipher, error) {
 | 
				
			||||||
	return into(c.Decrypt, s)
 | 
						c, err := aes.NewCipher(secret)
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// codecFunc is a function that takes a string and encodes/decodes it
 | 
					 | 
				
			||||||
type codecFunc func(string) (string, error)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func into(f codecFunc, s *string) error {
 | 
					 | 
				
			||||||
	// Do not encrypt/decrypt nil or empty strings
 | 
					 | 
				
			||||||
	if s == nil || *s == "" {
 | 
					 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	d, err := f(*s)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	*s = d
 | 
						return &gcmCipher{Block: c}, err
 | 
				
			||||||
	return nil
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Encrypt with AES GCM on raw bytes
 | 
				
			||||||
 | 
					func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) {
 | 
				
			||||||
 | 
						gcm, err := cipher.NewGCM(c.Block)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						nonce := make([]byte, gcm.NonceSize())
 | 
				
			||||||
 | 
						if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Using nonce as Seal's dst argument results in it being the first
 | 
				
			||||||
 | 
						// chunk of bytes in the ciphertext. Decrypt retrieves the nonce/IV from this.
 | 
				
			||||||
 | 
						ciphertext := gcm.Seal(nonce, nonce, value, nil)
 | 
				
			||||||
 | 
						return ciphertext, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Decrypt an AES GCM ciphertext
 | 
				
			||||||
 | 
					func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) {
 | 
				
			||||||
 | 
						gcm, err := cipher.NewGCM(c.Block)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						nonceSize := gcm.NonceSize()
 | 
				
			||||||
 | 
						nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return plaintext, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,8 +2,6 @@ package encryption
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"crypto/rand"
 | 
						"crypto/rand"
 | 
				
			||||||
	"crypto/sha1"
 | 
					 | 
				
			||||||
	"crypto/sha256"
 | 
					 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
| 
						 | 
					@ -12,107 +10,20 @@ import (
 | 
				
			||||||
	"github.com/stretchr/testify/assert"
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSecretBytesEncoded(t *testing.T) {
 | 
					 | 
				
			||||||
	for _, secretSize := range []int{16, 24, 32} {
 | 
					 | 
				
			||||||
		t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
					 | 
				
			||||||
			secret := make([]byte, secretSize)
 | 
					 | 
				
			||||||
			_, err := io.ReadFull(rand.Reader, secret)
 | 
					 | 
				
			||||||
			assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// We test both padded & raw Base64 to ensure we handle both
 | 
					 | 
				
			||||||
			// potential user input routes for Base64
 | 
					 | 
				
			||||||
			base64Padded := base64.URLEncoding.EncodeToString(secret)
 | 
					 | 
				
			||||||
			sb := SecretBytes(base64Padded)
 | 
					 | 
				
			||||||
			assert.Equal(t, secret, sb)
 | 
					 | 
				
			||||||
			assert.Equal(t, len(sb), secretSize)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			base64Raw := base64.RawURLEncoding.EncodeToString(secret)
 | 
					 | 
				
			||||||
			sb = SecretBytes(base64Raw)
 | 
					 | 
				
			||||||
			assert.Equal(t, secret, sb)
 | 
					 | 
				
			||||||
			assert.Equal(t, len(sb), secretSize)
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// A string that isn't intended as Base64 and still decodes (but to unintended length)
 | 
					 | 
				
			||||||
// will return the original secret as bytes
 | 
					 | 
				
			||||||
func TestSecretBytesEncodedWrongSize(t *testing.T) {
 | 
					 | 
				
			||||||
	for _, secretSize := range []int{15, 20, 28, 33, 44} {
 | 
					 | 
				
			||||||
		t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
					 | 
				
			||||||
			secret := make([]byte, secretSize)
 | 
					 | 
				
			||||||
			_, err := io.ReadFull(rand.Reader, secret)
 | 
					 | 
				
			||||||
			assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// We test both padded & raw Base64 to ensure we handle both
 | 
					 | 
				
			||||||
			// potential user input routes for Base64
 | 
					 | 
				
			||||||
			base64Padded := base64.URLEncoding.EncodeToString(secret)
 | 
					 | 
				
			||||||
			sb := SecretBytes(base64Padded)
 | 
					 | 
				
			||||||
			assert.NotEqual(t, secret, sb)
 | 
					 | 
				
			||||||
			assert.NotEqual(t, len(sb), secretSize)
 | 
					 | 
				
			||||||
			// The given secret is returned as []byte
 | 
					 | 
				
			||||||
			assert.Equal(t, base64Padded, string(sb))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			base64Raw := base64.RawURLEncoding.EncodeToString(secret)
 | 
					 | 
				
			||||||
			sb = SecretBytes(base64Raw)
 | 
					 | 
				
			||||||
			assert.NotEqual(t, secret, sb)
 | 
					 | 
				
			||||||
			assert.NotEqual(t, len(sb), secretSize)
 | 
					 | 
				
			||||||
			// The given secret is returned as []byte
 | 
					 | 
				
			||||||
			assert.Equal(t, base64Raw, string(sb))
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestSecretBytesNonBase64(t *testing.T) {
 | 
					 | 
				
			||||||
	trailer := "equals=========="
 | 
					 | 
				
			||||||
	assert.Equal(t, trailer, string(SecretBytes(trailer)))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	raw16 := "asdflkjhqwer)(*&"
 | 
					 | 
				
			||||||
	sb16 := SecretBytes(raw16)
 | 
					 | 
				
			||||||
	assert.Equal(t, raw16, string(sb16))
 | 
					 | 
				
			||||||
	assert.Equal(t, 16, len(sb16))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	raw24 := "asdflkjhqwer)(*&CJEN#$%^"
 | 
					 | 
				
			||||||
	sb24 := SecretBytes(raw24)
 | 
					 | 
				
			||||||
	assert.Equal(t, raw24, string(sb24))
 | 
					 | 
				
			||||||
	assert.Equal(t, 24, len(sb24))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&"
 | 
					 | 
				
			||||||
	sb32 := SecretBytes(raw32)
 | 
					 | 
				
			||||||
	assert.Equal(t, raw32, string(sb32))
 | 
					 | 
				
			||||||
	assert.Equal(t, 32, len(sb32))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestSignAndValidate(t *testing.T) {
 | 
					 | 
				
			||||||
	seed := "0123456789abcdef"
 | 
					 | 
				
			||||||
	key := "cookie-name"
 | 
					 | 
				
			||||||
	value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded"))
 | 
					 | 
				
			||||||
	epoch := "123456789"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sha256sig := cookieSignature(sha256.New, seed, key, value, epoch)
 | 
					 | 
				
			||||||
	sha1sig := cookieSignature(sha1.New, seed, key, value, epoch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	assert.True(t, checkSignature(sha256sig, seed, key, value, epoch))
 | 
					 | 
				
			||||||
	// This should be switched to False after fully deprecating SHA1
 | 
					 | 
				
			||||||
	assert.True(t, checkSignature(sha1sig, seed, key, value, epoch))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch))
 | 
					 | 
				
			||||||
	assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestEncodeAndDecodeAccessToken(t *testing.T) {
 | 
					func TestEncodeAndDecodeAccessToken(t *testing.T) {
 | 
				
			||||||
	const secret = "0123456789abcdefghijklmnopqrstuv"
 | 
						const secret = "0123456789abcdefghijklmnopqrstuv"
 | 
				
			||||||
	const token = "my access token"
 | 
						const token = "my access token"
 | 
				
			||||||
	c, err := NewCipher([]byte(secret))
 | 
						c, err := NewBase64Cipher(NewCFBCipher, []byte(secret))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	encoded, err := c.Encrypt(token)
 | 
						encoded, err := c.Encrypt([]byte(token))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	decoded, err := c.Decrypt(encoded)
 | 
						decoded, err := c.Decrypt(encoded)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.NotEqual(t, token, encoded)
 | 
						assert.NotEqual(t, []byte(token), encoded)
 | 
				
			||||||
	assert.Equal(t, token, decoded)
 | 
						assert.Equal(t, []byte(token), decoded)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
 | 
					func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
 | 
				
			||||||
| 
						 | 
					@ -121,37 +32,199 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	secret, err := base64.URLEncoding.DecodeString(secretBase64)
 | 
						secret, err := base64.URLEncoding.DecodeString(secretBase64)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	c, err := NewCipher([]byte(secret))
 | 
						c, err := NewBase64Cipher(NewCFBCipher, []byte(secret))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	encoded, err := c.Encrypt(token)
 | 
						encoded, err := c.Encrypt([]byte(token))
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	decoded, err := c.Decrypt(encoded)
 | 
						decoded, err := c.Decrypt(encoded)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.NotEqual(t, token, encoded)
 | 
						assert.NotEqual(t, []byte(token), encoded)
 | 
				
			||||||
	assert.Equal(t, token, decoded)
 | 
						assert.Equal(t, []byte(token), decoded)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) {
 | 
					func TestEncryptAndDecrypt(t *testing.T) {
 | 
				
			||||||
	const secret = "0123456789abcdefghijklmnopqrstuv"
 | 
						// Test our 2 cipher types
 | 
				
			||||||
	c, err := NewCipher([]byte(secret))
 | 
						cipherInits := map[string]func([]byte) (Cipher, error){
 | 
				
			||||||
 | 
							"CFB": NewCFBCipher,
 | 
				
			||||||
 | 
							"GCM": NewGCMCipher,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for name, initCipher := range cipherInits {
 | 
				
			||||||
 | 
							t.Run(name, func(t *testing.T) {
 | 
				
			||||||
 | 
								// Test all 3 valid AES sizes
 | 
				
			||||||
 | 
								for _, secretSize := range []int{16, 24, 32} {
 | 
				
			||||||
 | 
									t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
				
			||||||
 | 
										secret := make([]byte, secretSize)
 | 
				
			||||||
 | 
										_, err := io.ReadFull(rand.Reader, secret)
 | 
				
			||||||
					assert.Equal(t, nil, err)
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	token := "my access token"
 | 
										// Test Standard & Base64 wrapped
 | 
				
			||||||
	originalToken := token
 | 
										cstd, err := initCipher(secret)
 | 
				
			||||||
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Equal(t, nil, c.EncryptInto(&token))
 | 
										cb64, err := NewBase64Cipher(initCipher, secret)
 | 
				
			||||||
	assert.NotEqual(t, originalToken, token)
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	assert.Equal(t, nil, c.DecryptInto(&token))
 | 
										ciphers := map[string]Cipher{
 | 
				
			||||||
	assert.Equal(t, originalToken, token)
 | 
											"Standard": cstd,
 | 
				
			||||||
 | 
											"Base64":   cb64,
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check no errors with empty or nil strings
 | 
										for cName, c := range ciphers {
 | 
				
			||||||
	empty := ""
 | 
											t.Run(cName, func(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, nil, c.EncryptInto(&empty))
 | 
												// Test various sizes sessions might be
 | 
				
			||||||
	assert.Equal(t, nil, c.DecryptInto(&empty))
 | 
												for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
 | 
				
			||||||
	assert.Equal(t, nil, c.EncryptInto(nil))
 | 
													t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, nil, c.DecryptInto(nil))
 | 
														runEncryptAndDecrypt(t, c, dataSize)
 | 
				
			||||||
 | 
													})
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											})
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) {
 | 
				
			||||||
 | 
						data := make([]byte, dataSize)
 | 
				
			||||||
 | 
						_, err := io.ReadFull(rand.Reader, data)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Ensure our Encrypt function doesn't encrypt in place
 | 
				
			||||||
 | 
						immutableData := make([]byte, len(data))
 | 
				
			||||||
 | 
						copy(immutableData, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						encrypted, err := c.Encrypt(data)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.NotEqual(t, encrypted, data)
 | 
				
			||||||
 | 
						// Encrypt didn't operate in-place on []byte
 | 
				
			||||||
 | 
						assert.Equal(t, data, immutableData)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Ensure our Decrypt function doesn't decrypt in place
 | 
				
			||||||
 | 
						immutableEnc := make([]byte, len(encrypted))
 | 
				
			||||||
 | 
						copy(immutableEnc, encrypted)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						decrypted, err := c.Decrypt(encrypted)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						// Original data back
 | 
				
			||||||
 | 
						assert.Equal(t, data, decrypted)
 | 
				
			||||||
 | 
						// Decrypt didn't operate in-place on []byte
 | 
				
			||||||
 | 
						assert.Equal(t, encrypted, immutableEnc)
 | 
				
			||||||
 | 
						// Encrypt/Decrypt actually did something
 | 
				
			||||||
 | 
						assert.NotEqual(t, encrypted, decrypted)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestDecryptCFBWrongSecret(t *testing.T) {
 | 
				
			||||||
 | 
						secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
 | 
				
			||||||
 | 
						secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c1, err := NewCFBCipher(secret1)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c2, err := NewCFBCipher(secret2)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ciphertext, err := c1.Encrypt(data)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wrongData, err := c2.Decrypt(ciphertext)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.NotEqual(t, data, wrongData)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestDecryptGCMWrongSecret(t *testing.T) {
 | 
				
			||||||
 | 
						secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
 | 
				
			||||||
 | 
						secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c1, err := NewGCMCipher(secret1)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c2, err := NewGCMCipher(secret2)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ciphertext, err := c1.Encrypt(data)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// GCM is authenticated - this should lead to message authentication failed
 | 
				
			||||||
 | 
						_, err = c2.Decrypt(ciphertext)
 | 
				
			||||||
 | 
						assert.Error(t, err)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Encrypt with GCM, Decrypt with CFB: Results in Garbage data
 | 
				
			||||||
 | 
					func TestGCMtoCFBErrors(t *testing.T) {
 | 
				
			||||||
 | 
						// Test all 3 valid AES sizes
 | 
				
			||||||
 | 
						for _, secretSize := range []int{16, 24, 32} {
 | 
				
			||||||
 | 
							t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
				
			||||||
 | 
								secret := make([]byte, secretSize)
 | 
				
			||||||
 | 
								_, err := io.ReadFull(rand.Reader, secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								gcm, err := NewGCMCipher(secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								cfb, err := NewCFBCipher(secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Test various sizes sessions might be
 | 
				
			||||||
 | 
								for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
 | 
				
			||||||
 | 
									t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
 | 
				
			||||||
 | 
										data := make([]byte, dataSize)
 | 
				
			||||||
 | 
										_, err := io.ReadFull(rand.Reader, data)
 | 
				
			||||||
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										encrypted, err := gcm.Encrypt(data)
 | 
				
			||||||
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
										assert.NotEqual(t, encrypted, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										decrypted, err := cfb.Decrypt(encrypted)
 | 
				
			||||||
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
										// Data is mangled
 | 
				
			||||||
 | 
										assert.NotEqual(t, data, decrypted)
 | 
				
			||||||
 | 
										assert.NotEqual(t, encrypted, decrypted)
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Encrypt with CFB, Decrypt with GCM: Results in errors
 | 
				
			||||||
 | 
					func TestCFBtoGCMErrors(t *testing.T) {
 | 
				
			||||||
 | 
						// Test all 3 valid AES sizes
 | 
				
			||||||
 | 
						for _, secretSize := range []int{16, 24, 32} {
 | 
				
			||||||
 | 
							t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
				
			||||||
 | 
								secret := make([]byte, secretSize)
 | 
				
			||||||
 | 
								_, err := io.ReadFull(rand.Reader, secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								gcm, err := NewGCMCipher(secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								cfb, err := NewCFBCipher(secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Test various sizes sessions might be
 | 
				
			||||||
 | 
								for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
 | 
				
			||||||
 | 
									t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
 | 
				
			||||||
 | 
										data := make([]byte, dataSize)
 | 
				
			||||||
 | 
										_, err := io.ReadFull(rand.Reader, data)
 | 
				
			||||||
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										encrypted, err := cfb.Encrypt(data)
 | 
				
			||||||
 | 
										assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
										assert.NotEqual(t, encrypted, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										// GCM is authenticated - this should lead to message authentication failed
 | 
				
			||||||
 | 
										_, err = gcm.Decrypt(encrypted)
 | 
				
			||||||
 | 
										assert.Error(t, err)
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,106 @@
 | 
				
			||||||
 | 
					package encryption
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/hmac"
 | 
				
			||||||
 | 
						"crypto/sha1"
 | 
				
			||||||
 | 
						"crypto/sha256"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"hash"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
 | 
				
			||||||
 | 
					func SecretBytes(secret string) []byte {
 | 
				
			||||||
 | 
						b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "="))
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							// Only return decoded form if a valid AES length
 | 
				
			||||||
 | 
							// Don't want unintentional decoding resulting in invalid lengths confusing a user
 | 
				
			||||||
 | 
							// that thought they used a 16, 24, 32 length string
 | 
				
			||||||
 | 
							for _, i := range []int{16, 24, 32} {
 | 
				
			||||||
 | 
								if len(b) == i {
 | 
				
			||||||
 | 
									return b
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// If decoding didn't work or resulted in non-AES compliant length,
 | 
				
			||||||
 | 
						// assume the raw string was the intended secret
 | 
				
			||||||
 | 
						return []byte(secret)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
 | 
				
			||||||
 | 
					// additionally, the 'value' is encrypted so it's opaque to the browser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Validate ensures a cookie is properly signed
 | 
				
			||||||
 | 
					func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value []byte, t time.Time, ok bool) {
 | 
				
			||||||
 | 
						// value, timestamp, sig
 | 
				
			||||||
 | 
						parts := strings.Split(cookie.Value, "|")
 | 
				
			||||||
 | 
						if len(parts) != 3 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) {
 | 
				
			||||||
 | 
							ts, err := strconv.Atoi(parts[1])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// The expiration timestamp set when the cookie was created
 | 
				
			||||||
 | 
							// isn't sent back by the browser. Hence, we check whether the
 | 
				
			||||||
 | 
							// creation timestamp stored in the cookie falls within the
 | 
				
			||||||
 | 
							// window defined by (Now()-expiration, Now()].
 | 
				
			||||||
 | 
							t = time.Unix(int64(ts), 0)
 | 
				
			||||||
 | 
							if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
 | 
				
			||||||
 | 
								// it's a valid cookie. now get the contents
 | 
				
			||||||
 | 
								rawValue, err := base64.URLEncoding.DecodeString(parts[0])
 | 
				
			||||||
 | 
								if err == nil {
 | 
				
			||||||
 | 
									value = rawValue
 | 
				
			||||||
 | 
									ok = true
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SignedValue returns a cookie that is signed and can later be checked with Validate
 | 
				
			||||||
 | 
					func SignedValue(seed string, key string, value []byte, now time.Time) string {
 | 
				
			||||||
 | 
						encodedValue := base64.URLEncoding.EncodeToString(value)
 | 
				
			||||||
 | 
						timeStr := fmt.Sprintf("%d", now.Unix())
 | 
				
			||||||
 | 
						sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr)
 | 
				
			||||||
 | 
						cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
 | 
				
			||||||
 | 
						return cookieVal
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func cookieSignature(signer func() hash.Hash, args ...string) string {
 | 
				
			||||||
 | 
						h := hmac.New(signer, []byte(args[0]))
 | 
				
			||||||
 | 
						for _, arg := range args[1:] {
 | 
				
			||||||
 | 
							h.Write([]byte(arg))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var b []byte
 | 
				
			||||||
 | 
						b = h.Sum(b)
 | 
				
			||||||
 | 
						return base64.URLEncoding.EncodeToString(b)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkSignature(signature string, args ...string) bool {
 | 
				
			||||||
 | 
						checkSig := cookieSignature(sha256.New, args...)
 | 
				
			||||||
 | 
						if checkHmac(signature, checkSig) {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO: After appropriate rollout window, remove support for SHA1
 | 
				
			||||||
 | 
						legacySig := cookieSignature(sha1.New, args...)
 | 
				
			||||||
 | 
						return checkHmac(signature, legacySig)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkHmac(input, expected string) bool {
 | 
				
			||||||
 | 
						inputMAC, err1 := base64.URLEncoding.DecodeString(input)
 | 
				
			||||||
 | 
						if err1 == nil {
 | 
				
			||||||
 | 
							expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
 | 
				
			||||||
 | 
							if err2 == nil {
 | 
				
			||||||
 | 
								return hmac.Equal(inputMAC, expectedMAC)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,100 @@
 | 
				
			||||||
 | 
					package encryption
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"crypto/sha1"
 | 
				
			||||||
 | 
						"crypto/sha256"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSecretBytesEncoded(t *testing.T) {
 | 
				
			||||||
 | 
						for _, secretSize := range []int{16, 24, 32} {
 | 
				
			||||||
 | 
							t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
				
			||||||
 | 
								secret := make([]byte, secretSize)
 | 
				
			||||||
 | 
								_, err := io.ReadFull(rand.Reader, secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// We test both padded & raw Base64 to ensure we handle both
 | 
				
			||||||
 | 
								// potential user input routes for Base64
 | 
				
			||||||
 | 
								base64Padded := base64.URLEncoding.EncodeToString(secret)
 | 
				
			||||||
 | 
								sb := SecretBytes(base64Padded)
 | 
				
			||||||
 | 
								assert.Equal(t, secret, sb)
 | 
				
			||||||
 | 
								assert.Equal(t, len(sb), secretSize)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								base64Raw := base64.RawURLEncoding.EncodeToString(secret)
 | 
				
			||||||
 | 
								sb = SecretBytes(base64Raw)
 | 
				
			||||||
 | 
								assert.Equal(t, secret, sb)
 | 
				
			||||||
 | 
								assert.Equal(t, len(sb), secretSize)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// A string that isn't intended as Base64 and still decodes (but to unintended length)
 | 
				
			||||||
 | 
					// will return the original secret as bytes
 | 
				
			||||||
 | 
					func TestSecretBytesEncodedWrongSize(t *testing.T) {
 | 
				
			||||||
 | 
						for _, secretSize := range []int{15, 20, 28, 33, 44} {
 | 
				
			||||||
 | 
							t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
 | 
				
			||||||
 | 
								secret := make([]byte, secretSize)
 | 
				
			||||||
 | 
								_, err := io.ReadFull(rand.Reader, secret)
 | 
				
			||||||
 | 
								assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// We test both padded & raw Base64 to ensure we handle both
 | 
				
			||||||
 | 
								// potential user input routes for Base64
 | 
				
			||||||
 | 
								base64Padded := base64.URLEncoding.EncodeToString(secret)
 | 
				
			||||||
 | 
								sb := SecretBytes(base64Padded)
 | 
				
			||||||
 | 
								assert.NotEqual(t, secret, sb)
 | 
				
			||||||
 | 
								assert.NotEqual(t, len(sb), secretSize)
 | 
				
			||||||
 | 
								// The given secret is returned as []byte
 | 
				
			||||||
 | 
								assert.Equal(t, base64Padded, string(sb))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								base64Raw := base64.RawURLEncoding.EncodeToString(secret)
 | 
				
			||||||
 | 
								sb = SecretBytes(base64Raw)
 | 
				
			||||||
 | 
								assert.NotEqual(t, secret, sb)
 | 
				
			||||||
 | 
								assert.NotEqual(t, len(sb), secretSize)
 | 
				
			||||||
 | 
								// The given secret is returned as []byte
 | 
				
			||||||
 | 
								assert.Equal(t, base64Raw, string(sb))
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSecretBytesNonBase64(t *testing.T) {
 | 
				
			||||||
 | 
						trailer := "equals=========="
 | 
				
			||||||
 | 
						assert.Equal(t, trailer, string(SecretBytes(trailer)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						raw16 := "asdflkjhqwer)(*&"
 | 
				
			||||||
 | 
						sb16 := SecretBytes(raw16)
 | 
				
			||||||
 | 
						assert.Equal(t, raw16, string(sb16))
 | 
				
			||||||
 | 
						assert.Equal(t, 16, len(sb16))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						raw24 := "asdflkjhqwer)(*&CJEN#$%^"
 | 
				
			||||||
 | 
						sb24 := SecretBytes(raw24)
 | 
				
			||||||
 | 
						assert.Equal(t, raw24, string(sb24))
 | 
				
			||||||
 | 
						assert.Equal(t, 24, len(sb24))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&"
 | 
				
			||||||
 | 
						sb32 := SecretBytes(raw32)
 | 
				
			||||||
 | 
						assert.Equal(t, raw32, string(sb32))
 | 
				
			||||||
 | 
						assert.Equal(t, 32, len(sb32))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSignAndValidate(t *testing.T) {
 | 
				
			||||||
 | 
						seed := "0123456789abcdef"
 | 
				
			||||||
 | 
						key := "cookie-name"
 | 
				
			||||||
 | 
						value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded"))
 | 
				
			||||||
 | 
						epoch := "123456789"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sha256sig := cookieSignature(sha256.New, seed, key, value, epoch)
 | 
				
			||||||
 | 
						sha1sig := cookieSignature(sha1.New, seed, key, value, epoch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.True(t, checkSignature(sha256sig, seed, key, value, epoch))
 | 
				
			||||||
 | 
						// This should be switched to False after fully deprecating SHA1
 | 
				
			||||||
 | 
						assert.True(t, checkSignature(sha1sig, seed, key, value, epoch))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch))
 | 
				
			||||||
 | 
						assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -28,7 +28,7 @@ var _ sessions.SessionStore = &SessionStore{}
 | 
				
			||||||
// interface that stores sessions in client side cookies
 | 
					// interface that stores sessions in client side cookies
 | 
				
			||||||
type SessionStore struct {
 | 
					type SessionStore struct {
 | 
				
			||||||
	CookieOptions *options.CookieOptions
 | 
						CookieOptions *options.CookieOptions
 | 
				
			||||||
	CookieCipher  *encryption.Cipher
 | 
						CookieCipher  encryption.Cipher
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Save takes a sessions.SessionState and stores the information from it
 | 
					// Save takes a sessions.SessionState and stores the information from it
 | 
				
			||||||
| 
						 | 
					@ -59,7 +59,7 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
 | 
				
			||||||
		return nil, errors.New("cookie signature not valid")
 | 
							return nil, errors.New("cookie signature not valid")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	session, err := sessionFromCookie(val, s.CookieCipher)
 | 
						session, err := sessionFromCookie(string(val), s.CookieCipher)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -84,12 +84,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// cookieForSession serializes a session state for storage in a cookie
 | 
					// cookieForSession serializes a session state for storage in a cookie
 | 
				
			||||||
func cookieForSession(s *sessions.SessionState, c *encryption.Cipher) (string, error) {
 | 
					func cookieForSession(s *sessions.SessionState, c encryption.Cipher) (string, error) {
 | 
				
			||||||
	return s.EncodeSessionState(c)
 | 
						return s.EncodeSessionState(c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// sessionFromCookie deserializes a session from a cookie value
 | 
					// sessionFromCookie deserializes a session from a cookie value
 | 
				
			||||||
func sessionFromCookie(v string, c *encryption.Cipher) (s *sessions.SessionState, err error) {
 | 
					func sessionFromCookie(v string, c encryption.Cipher) (s *sessions.SessionState, err error) {
 | 
				
			||||||
	return sessions.DecodeSessionState(v, c)
 | 
						return sessions.DecodeSessionState(v, c)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -104,7 +104,7 @@ func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Reques
 | 
				
			||||||
// authentication details
 | 
					// authentication details
 | 
				
			||||||
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie {
 | 
					func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie {
 | 
				
			||||||
	if value != "" {
 | 
						if value != "" {
 | 
				
			||||||
		value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, value, now)
 | 
							value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, []byte(value), now)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c := s.makeCookie(req, s.CookieOptions.Name, value, s.CookieOptions.Expire, now)
 | 
						c := s.makeCookie(req, s.CookieOptions.Name, value, s.CookieOptions.Expire, now)
 | 
				
			||||||
	if len(c.Value) > 4096-len(s.CookieOptions.Name) {
 | 
						if len(c.Value) > 4096-len(s.CookieOptions.Name) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,7 +32,7 @@ type TicketData struct {
 | 
				
			||||||
// SessionStore is an implementation of the sessions.SessionStore
 | 
					// SessionStore is an implementation of the sessions.SessionStore
 | 
				
			||||||
// interface that stores sessions in redis
 | 
					// interface that stores sessions in redis
 | 
				
			||||||
type SessionStore struct {
 | 
					type SessionStore struct {
 | 
				
			||||||
	CookieCipher  *encryption.Cipher
 | 
						CookieCipher  encryption.Cipher
 | 
				
			||||||
	CookieOptions *options.CookieOptions
 | 
						CookieOptions *options.CookieOptions
 | 
				
			||||||
	Client        Client
 | 
						Client        Client
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -175,7 +175,7 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro
 | 
				
			||||||
		return nil, fmt.Errorf("cookie signature not valid")
 | 
							return nil, fmt.Errorf("cookie signature not valid")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	ctx := req.Context()
 | 
						ctx := req.Context()
 | 
				
			||||||
	session, err := store.loadSessionFromString(ctx, val)
 | 
						session, err := store.loadSessionFromString(ctx, string(val))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("error loading session: %s", err)
 | 
							return nil, fmt.Errorf("error loading session: %s", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -237,7 +237,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// We only return an error if we had an issue with redis
 | 
						// We only return an error if we had an issue with redis
 | 
				
			||||||
	// If there's an issue decoding the ticket, ignore it
 | 
						// If there's an issue decoding the ticket, ignore it
 | 
				
			||||||
	ticket, _ := decodeTicket(store.CookieOptions.Name, val)
 | 
						ticket, _ := decodeTicket(store.CookieOptions.Name, string(val))
 | 
				
			||||||
	if ticket != nil {
 | 
						if ticket != nil {
 | 
				
			||||||
		ctx := req.Context()
 | 
							ctx := req.Context()
 | 
				
			||||||
		err := store.Client.Del(ctx, ticket.asHandle(store.CookieOptions.Name))
 | 
							err := store.Client.Del(ctx, ticket.asHandle(store.CookieOptions.Name))
 | 
				
			||||||
| 
						 | 
					@ -251,7 +251,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
 | 
				
			||||||
// makeCookie makes a cookie, signing the value if present
 | 
					// makeCookie makes a cookie, signing the value if present
 | 
				
			||||||
func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie {
 | 
					func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie {
 | 
				
			||||||
	if value != "" {
 | 
						if value != "" {
 | 
				
			||||||
		value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, value, now)
 | 
							value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, []byte(value), now)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return cookies.MakeCookieFromOptions(
 | 
						return cookies.MakeCookieFromOptions(
 | 
				
			||||||
		req,
 | 
							req,
 | 
				
			||||||
| 
						 | 
					@ -302,7 +302,7 @@ func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, e
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Valid cookie, decode the ticket
 | 
						// Valid cookie, decode the ticket
 | 
				
			||||||
	ticket, err := decodeTicket(store.CookieOptions.Name, val)
 | 
						ticket, err := decodeTicket(store.CookieOptions.Name, string(val))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		// If we can't decode the ticket we have to create a new one
 | 
							// If we can't decode the ticket we have to create a new one
 | 
				
			||||||
		return newTicket()
 | 
							return newTicket()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -170,7 +170,7 @@ var _ = Describe("NewSessionStore", func() {
 | 
				
			||||||
				BeforeEach(func() {
 | 
									BeforeEach(func() {
 | 
				
			||||||
					By("Using a valid cookie with a different providers session encoding")
 | 
										By("Using a valid cookie with a different providers session encoding")
 | 
				
			||||||
					broken := "BrokenSessionFromADifferentSessionImplementation"
 | 
										broken := "BrokenSessionFromADifferentSessionImplementation"
 | 
				
			||||||
					value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, broken, time.Now())
 | 
										value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, []byte(broken), time.Now())
 | 
				
			||||||
					cookie := cookiesapi.MakeCookieFromOptions(request, cookieOpts.Name, value, cookieOpts, cookieOpts.Expire, time.Now())
 | 
										cookie := cookiesapi.MakeCookieFromOptions(request, cookieOpts.Name, value, cookieOpts, cookieOpts.Expire, time.Now())
 | 
				
			||||||
					request.AddCookie(cookie)
 | 
										request.AddCookie(cookie)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -367,7 +367,7 @@ var _ = Describe("NewSessionStore", func() {
 | 
				
			||||||
				_, err := rand.Read(secret)
 | 
									_, err := rand.Read(secret)
 | 
				
			||||||
				Expect(err).ToNot(HaveOccurred())
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
				cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret)
 | 
									cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret)
 | 
				
			||||||
				cipher, err := encryption.NewCipher(encryption.SecretBytes(cookieOpts.Secret))
 | 
									cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret))
 | 
				
			||||||
				Expect(err).ToNot(HaveOccurred())
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
				Expect(cipher).ToNot(BeNil())
 | 
									Expect(cipher).ToNot(BeNil())
 | 
				
			||||||
				opts.Cipher = cipher
 | 
									opts.Cipher = cipher
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -38,7 +38,7 @@ func Validate(o *options.Options) error {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	msgs := make([]string, 0)
 | 
						msgs := make([]string, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var cipher *encryption.Cipher
 | 
						var cipher encryption.Cipher
 | 
				
			||||||
	if o.Cookie.Secret == "" {
 | 
						if o.Cookie.Secret == "" {
 | 
				
			||||||
		msgs = append(msgs, "missing setting: cookie-secret")
 | 
							msgs = append(msgs, "missing setting: cookie-secret")
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
| 
						 | 
					@ -62,7 +62,7 @@ func Validate(o *options.Options) error {
 | 
				
			||||||
					len(encryption.SecretBytes(o.Cookie.Secret)), suffix))
 | 
										len(encryption.SecretBytes(o.Cookie.Secret)), suffix))
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			var err error
 | 
								var err error
 | 
				
			||||||
			cipher, err = encryption.NewCipher(encryption.SecretBytes(o.Cookie.Secret))
 | 
								cipher, err = encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(o.Cookie.Secret))
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err))
 | 
									msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err))
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue