Centralize Ticket management of persistent stores (#682)
* Centralize Ticket management of persistent stores persistence package with Manager & Ticket will handle all the details about keys, secrets, ticket into cookies, etc. Persistent stores just need to pass Save, Load & Clear function handles to the persistent manager now. * Shift to persistence.Manager wrapping a persistence.Store * Break up the Redis client builder logic * Move error messages to Store from Manager * Convert ticket to private for Manager use only * Add persistence Manager & ticket tests * Make a custom MockStore that handles time FastForwards
This commit is contained in:
		
							parent
							
								
									f141f7cea0
								
							
						
					
					
						commit
						9643a0b10c
					
				|  | @ -11,6 +11,7 @@ | |||
| 
 | ||||
| ## Changes since v6.0.0 | ||||
| 
 | ||||
| - [#682](https://github.com/oauth2-proxy/oauth2-proxy/pull/682) Refactor persistent session store session ticket management (@NickMeves) | ||||
| - [#688](https://github.com/oauth2-proxy/oauth2-proxy/pull/688) Refactor session loading to make use of middleware pattern (@JoelSpeed) | ||||
| - [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed) | ||||
| - [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed) | ||||
|  |  | |||
|  | @ -9,6 +9,8 @@ import ( | |||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| const LegacyV5TestSecret = "0123456789abcdefghijklmnopqrstuv" | ||||
| 
 | ||||
| // LegacyV5TestCase provides V5 JSON based test cases for legacy fallback code
 | ||||
| type LegacyV5TestCase struct { | ||||
| 	Input  string | ||||
|  | @ -22,8 +24,6 @@ type LegacyV5TestCase struct { | |||
| //
 | ||||
| // TODO: Remove when this is deprecated (likely V7)
 | ||||
| func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encryption.Cipher, encryption.Cipher) { | ||||
| 	const secret = "0123456789abcdefghijklmnopqrstuv" | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	createdJSON, err := created.MarshalJSON() | ||||
| 	assert.NoError(t, err) | ||||
|  | @ -33,7 +33,7 @@ func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encrypt | |||
| 	assert.NoError(t, err) | ||||
| 	eString := string(eJSON) | ||||
| 
 | ||||
| 	cfbCipher, err := encryption.NewCFBCipher([]byte(secret)) | ||||
| 	cfbCipher, err := encryption.NewCFBCipher([]byte(LegacyV5TestSecret)) | ||||
| 	assert.NoError(t, err) | ||||
| 	legacyCipher := encryption.NewBase64Cipher(cfbCipher) | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,15 @@ | |||
| package persistence | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // Store is used for persistent session stores (IE not Cookie)
 | ||||
| // Implementing this interface allows it to easily use the persistence.Manager
 | ||||
| // for session ticket + encryption details.
 | ||||
| type Store interface { | ||||
| 	Save(context.Context, string, []byte, time.Duration) error | ||||
| 	Load(context.Context, string) ([]byte, error) | ||||
| 	Clear(context.Context, string) error | ||||
| } | ||||
|  | @ -0,0 +1,91 @@ | |||
| package persistence | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // Manager wraps a Store and handles the implementation details of the
 | ||||
| // sessions.SessionStore with its use of session tickets
 | ||||
| type Manager struct { | ||||
| 	Store   Store | ||||
| 	Options *options.Cookie | ||||
| } | ||||
| 
 | ||||
| // NewManager creates a Manager that can wrap a Store and manage the
 | ||||
| // sessions.SessionStore implementation details
 | ||||
| func NewManager(store Store, cookieOpts *options.Cookie) *Manager { | ||||
| 	return &Manager{ | ||||
| 		Store:   store, | ||||
| 		Options: cookieOpts, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Save saves a session in a persistent Store. Save will generate (or reuse an
 | ||||
| // existing) ticket which manages unique per session encryption & retrieval
 | ||||
| // from the persistent data store.
 | ||||
| func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | ||||
| 	if s.CreatedAt == nil || s.CreatedAt.IsZero() { | ||||
| 		now := time.Now() | ||||
| 		s.CreatedAt = &now | ||||
| 	} | ||||
| 
 | ||||
| 	tckt, err := decodeTicketFromRequest(req, m.Options) | ||||
| 	if err != nil { | ||||
| 		tckt, err = newTicket(m.Options) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error creating a session ticket: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = tckt.saveSession(s, func(key string, val []byte, exp time.Duration) error { | ||||
| 		return m.Store.Save(req.Context(), key, val, exp) | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	tckt.setCookie(rw, req, s) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Load reads sessions.SessionState information from a session store. It will
 | ||||
| // use the session ticket from the http.Request's cookie.
 | ||||
| func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) { | ||||
| 	tckt, err := decodeTicketFromRequest(req, m.Options) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return tckt.loadSession(func(key string) ([]byte, error) { | ||||
| 		return m.Store.Load(req.Context(), key) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // Clear clears any saved session information for a given ticket cookie.
 | ||||
| // Then it clears all session data for that ticket in the Store.
 | ||||
| func (m *Manager) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| 	tckt, err := decodeTicketFromRequest(req, m.Options) | ||||
| 	if err != nil { | ||||
| 		// Always clear the cookie, even when we can't load a cookie from
 | ||||
| 		// the request
 | ||||
| 		tckt = &ticket{ | ||||
| 			options: m.Options, | ||||
| 		} | ||||
| 		tckt.clearCookie(rw, req) | ||||
| 		// Don't raise an error if we didn't have a Cookie
 | ||||
| 		if err == http.ErrNoCookie { | ||||
| 			return nil | ||||
| 		} | ||||
| 		return fmt.Errorf("error decoding ticket to clear session: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	tckt.clearCookie(rw, req) | ||||
| 	return tckt.clearSession(func(key string) error { | ||||
| 		return m.Store.Clear(req.Context(), key) | ||||
| 	}) | ||||
| } | ||||
|  | @ -0,0 +1,34 @@ | |||
| package persistence | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| func TestManager(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "Persistence Manager SessionStore") | ||||
| } | ||||
| 
 | ||||
| var _ = Describe("Persistence Manager SessionStore Tests", func() { | ||||
| 	var ms *tests.MockStore | ||||
| 	BeforeEach(func() { | ||||
| 		ms = tests.NewMockStore() | ||||
| 	}) | ||||
| 	tests.RunSessionStoreTests( | ||||
| 		func(_ *options.SessionOptions, cookieOpts *options.Cookie) (sessionsapi.SessionStore, error) { | ||||
| 			return NewManager(ms, cookieOpts), nil | ||||
| 		}, | ||||
| 		func(d time.Duration) error { | ||||
| 			ms.FastForward(d) | ||||
| 			return nil | ||||
| 		}) | ||||
| }) | ||||
|  | @ -0,0 +1,221 @@ | |||
| package persistence | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | ||||
| ) | ||||
| 
 | ||||
| // saveFunc performs a persistent store's save functionality using
 | ||||
| // a key string, value []byte & (optional) expiration time.Duration
 | ||||
| type saveFunc func(string, []byte, time.Duration) error | ||||
| 
 | ||||
| // loadFunc performs a load from a persistent store using a
 | ||||
| // string key and returning the stored value as []byte
 | ||||
| type loadFunc func(string) ([]byte, error) | ||||
| 
 | ||||
| // clearFunc performs a persistent store's clear functionality using
 | ||||
| // a string key for the target of the deletion.
 | ||||
| type clearFunc func(string) error | ||||
| 
 | ||||
| // ticket is a structure representing the ticket used in server based
 | ||||
| // session storage. It provides a unique per session decryption secret giving
 | ||||
| // more security than the shared CookieSecret.
 | ||||
| type ticket struct { | ||||
| 	id      string | ||||
| 	secret  []byte | ||||
| 	options *options.Cookie | ||||
| } | ||||
| 
 | ||||
| // newTicket creates a new ticket. The ID & secret will be randomly created
 | ||||
| // with 16 byte sizes. The ID will be prefixed & hex encoded.
 | ||||
| func newTicket(cookieOpts *options.Cookie) (*ticket, error) { | ||||
| 	rawID := make([]byte, 16) | ||||
| 	if _, err := io.ReadFull(rand.Reader, rawID); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create new ticket ID: %v", err) | ||||
| 	} | ||||
| 	// ticketID is hex encoded
 | ||||
| 	ticketID := fmt.Sprintf("%s-%s", cookieOpts.Name, hex.EncodeToString(rawID)) | ||||
| 
 | ||||
| 	secret := make([]byte, aes.BlockSize) | ||||
| 	if _, err := io.ReadFull(rand.Reader, secret); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create encryption secret: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return &ticket{ | ||||
| 		id:      ticketID, | ||||
| 		secret:  secret, | ||||
| 		options: cookieOpts, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // encodeTicket encodes the Ticket to a string for usage in cookies
 | ||||
| func (t *ticket) encodeTicket() string { | ||||
| 	return fmt.Sprintf("%s.%s", t.id, base64.RawURLEncoding.EncodeToString(t.secret)) | ||||
| } | ||||
| 
 | ||||
| // decodeTicket decodes an encoded ticket string
 | ||||
| func decodeTicket(encTicket string, cookieOpts *options.Cookie) (*ticket, error) { | ||||
| 	ticketParts := strings.Split(encTicket, ".") | ||||
| 	if len(ticketParts) != 2 { | ||||
| 		return nil, errors.New("failed to decode ticket") | ||||
| 	} | ||||
| 	ticketID, secretBase64 := ticketParts[0], ticketParts[1] | ||||
| 
 | ||||
| 	secret, err := base64.RawURLEncoding.DecodeString(secretBase64) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to decode encryption secret: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return &ticket{ | ||||
| 		id:      ticketID, | ||||
| 		secret:  secret, | ||||
| 		options: cookieOpts, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // decodeTicketFromRequest retrieves a potential ticket cookie from a request
 | ||||
| // and decodes it to a ticket.
 | ||||
| func decodeTicketFromRequest(req *http.Request, cookieOpts *options.Cookie) (*ticket, error) { | ||||
| 	requestCookie, err := req.Cookie(cookieOpts.Name) | ||||
| 	if err != nil { | ||||
| 		// Don't wrap this error to allow `err == http.ErrNoCookie` checks
 | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// An existing cookie exists, try to retrieve the ticket
 | ||||
| 	val, _, ok := encryption.Validate(requestCookie, cookieOpts.Secret, cookieOpts.Expire) | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("session ticket cookie failed validation: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Valid cookie, decode the ticket
 | ||||
| 	return decodeTicket(string(val), cookieOpts) | ||||
| } | ||||
| 
 | ||||
| // saveSession encodes the SessionState with the ticket's secret and persists
 | ||||
| // it to disk via the passed saveFunc.
 | ||||
| func (t *ticket) saveSession(s *sessions.SessionState, saver saveFunc) error { | ||||
| 	c, err := t.makeCipher() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	ciphertext, err := s.EncodeSessionState(c, false) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to encode the session state with the ticket: %v", err) | ||||
| 	} | ||||
| 	return saver(t.id, ciphertext, t.options.Expire) | ||||
| } | ||||
| 
 | ||||
| // loadSession loads a session from the disk store via the passed loadFunc
 | ||||
| // using the ticket.id as the key. It then decodes the SessionState using
 | ||||
| // ticket.secret to make the AES-GCM cipher.
 | ||||
| //
 | ||||
| // TODO (@NickMeves): Remove legacyV5LoadSession support in V7
 | ||||
| func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) { | ||||
| 	ciphertext, err := loader(t.id) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to load the session state with the ticket: %v", err) | ||||
| 	} | ||||
| 	c, err := t.makeCipher() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ss, err := sessions.DecodeSessionState(ciphertext, c, false) | ||||
| 	if err != nil { | ||||
| 		return t.legacyV5LoadSession(ciphertext) | ||||
| 	} | ||||
| 	return ss, nil | ||||
| } | ||||
| 
 | ||||
| // clearSession uses the passed clearFunc to delete a session stored with a
 | ||||
| // key of ticket.id
 | ||||
| func (t *ticket) clearSession(clearer clearFunc) error { | ||||
| 	return clearer(t.id) | ||||
| } | ||||
| 
 | ||||
| // setCookie sets the encoded ticket as a cookie
 | ||||
| func (t *ticket) setCookie(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) { | ||||
| 	ticketCookie := t.makeCookie( | ||||
| 		req, | ||||
| 		t.encodeTicket(), | ||||
| 		t.options.Expire, | ||||
| 		*s.CreatedAt, | ||||
| 	) | ||||
| 
 | ||||
| 	http.SetCookie(rw, ticketCookie) | ||||
| } | ||||
| 
 | ||||
| // clearCookie removes any cookies that would be where this ticket
 | ||||
| // would set them
 | ||||
| func (t *ticket) clearCookie(rw http.ResponseWriter, req *http.Request) { | ||||
| 	clearCookie := t.makeCookie( | ||||
| 		req, | ||||
| 		"", | ||||
| 		time.Hour*-1, | ||||
| 		time.Now(), | ||||
| 	) | ||||
| 	http.SetCookie(rw, clearCookie) | ||||
| } | ||||
| 
 | ||||
| // makeCookie makes a cookie, signing the value if present
 | ||||
| func (t *ticket) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { | ||||
| 	if value != "" { | ||||
| 		value = encryption.SignedValue(t.options.Secret, t.options.Name, []byte(value), now) | ||||
| 	} | ||||
| 	return cookies.MakeCookieFromOptions( | ||||
| 		req, | ||||
| 		t.options.Name, | ||||
| 		value, | ||||
| 		t.options, | ||||
| 		expires, | ||||
| 		now, | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| // makeCipher makes a AES-GCM cipher out of the ticket's secret
 | ||||
| func (t *ticket) makeCipher() (encryption.Cipher, error) { | ||||
| 	c, err := encryption.NewGCMCipher(t.secret) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to make an AES-GCM cipher from the ticket secret: %v", err) | ||||
| 	} | ||||
| 	return c, nil | ||||
| } | ||||
| 
 | ||||
| // legacyV5LoadSession loads a Redis session created in V5 with historical logic
 | ||||
| //
 | ||||
| // TODO (@NickMeves): Remove in V7
 | ||||
| func (t *ticket) legacyV5LoadSession(resultBytes []byte) (*sessions.SessionState, error) { | ||||
| 	block, err := aes.NewCipher(t.secret) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create a legacy AES-CFB cipher from the ticket secret: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	stream := cipher.NewCFBDecrypter(block, t.secret) | ||||
| 	stream.XORKeyStream(resultBytes, resultBytes) | ||||
| 
 | ||||
| 	cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(t.options.Secret)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	legacyCipher := encryption.NewBase64Cipher(cfbCipher) | ||||
| 
 | ||||
| 	session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), legacyCipher) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return session, nil | ||||
| } | ||||
|  | @ -0,0 +1,223 @@ | |||
| package persistence | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| func Test_ticket(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "Session Ticket") | ||||
| } | ||||
| 
 | ||||
| var _ = Describe("Session Ticket Tests", func() { | ||||
| 	Context("encodeTicket & decodeTicket", func() { | ||||
| 		type ticketTableInput struct { | ||||
| 			ticket        *ticket | ||||
| 			encodedTicket string | ||||
| 			expectedError error | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("encodeTicket should decodeTicket back when valid", | ||||
| 			func(in ticketTableInput) { | ||||
| 				if in.ticket != nil { | ||||
| 					enc := in.ticket.encodeTicket() | ||||
| 					Expect(enc).To(Equal(in.encodedTicket)) | ||||
| 
 | ||||
| 					dec, err := decodeTicket(enc, in.ticket.options) | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 					Expect(dec).To(Equal(in.ticket)) | ||||
| 				} else { | ||||
| 					_, err := decodeTicket(in.encodedTicket, nil) | ||||
| 					Expect(err).To(MatchError(in.expectedError)) | ||||
| 				} | ||||
| 			}, | ||||
| 			Entry("with a valid ticket", ticketTableInput{ | ||||
| 				ticket: &ticket{ | ||||
| 					id:     "dummy-0123456789abcdef", | ||||
| 					secret: []byte("0123456789abcdef"), | ||||
| 					options: &options.Cookie{ | ||||
| 						Name: "dummy", | ||||
| 					}, | ||||
| 				}, | ||||
| 				encodedTicket: fmt.Sprintf("%s.%s", | ||||
| 					"dummy-0123456789abcdef", | ||||
| 					base64.RawURLEncoding.EncodeToString([]byte("0123456789abcdef"))), | ||||
| 				expectedError: nil, | ||||
| 			}), | ||||
| 			Entry("with an invalid encoded ticket with 1 part", ticketTableInput{ | ||||
| 				ticket:        nil, | ||||
| 				encodedTicket: "dummy-0123456789abcdef", | ||||
| 				expectedError: errors.New("failed to decode ticket"), | ||||
| 			}), | ||||
| 			Entry("with an invalid base64 encoded secret", ticketTableInput{ | ||||
| 				ticket:        nil, | ||||
| 				encodedTicket: "dummy-0123456789abcdef.@)#($*@)#(*$@)#(*$", | ||||
| 				expectedError: fmt.Errorf("failed to decode encryption secret: illegal base64 data at input byte 0"), | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("saveSession", func() { | ||||
| 		It("uses the passed save function", func() { | ||||
| 			t, err := newTicket(&options.Cookie{Name: "dummy"}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			c, err := t.makeCipher() | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			ss := &sessions.SessionState{User: "foobar"} | ||||
| 			store := map[string][]byte{} | ||||
| 			err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error { | ||||
| 				store[k] = v | ||||
| 				return nil | ||||
| 			}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			stored, err := sessions.DecodeSessionState(store[t.id], c, false) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(stored).To(Equal(ss)) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("errors when the saveFunc errors", func() { | ||||
| 			t, err := newTicket(&options.Cookie{Name: "dummy"}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			err = t.saveSession( | ||||
| 				&sessions.SessionState{User: "foobar"}, | ||||
| 				func(k string, v []byte, e time.Duration) error { | ||||
| 					return errors.New("save error") | ||||
| 				}) | ||||
| 			Expect(err).To(MatchError(errors.New("save error"))) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("loadSession", func() { | ||||
| 		It("uses the passed load function", func() { | ||||
| 			t, err := newTicket(&options.Cookie{Name: "dummy"}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			c, err := t.makeCipher() | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			ss := &sessions.SessionState{User: "foobar"} | ||||
| 			loadedSession, err := t.loadSession(func(k string) ([]byte, error) { | ||||
| 				return ss.EncodeSessionState(c, false) | ||||
| 			}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(loadedSession).To(Equal(ss)) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("errors when the loadFunc errors", func() { | ||||
| 			t, err := newTicket(&options.Cookie{Name: "dummy"}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			data, err := t.loadSession(func(k string) ([]byte, error) { | ||||
| 				return nil, errors.New("load error") | ||||
| 			}) | ||||
| 			Expect(data).To(BeNil()) | ||||
| 			Expect(err).To(MatchError(errors.New("failed to load the session state with the ticket: load error"))) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("clearSession", func() { | ||||
| 		It("uses the passed clear function", func() { | ||||
| 			t, err := newTicket(&options.Cookie{Name: "dummy"}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			var tracker string | ||||
| 			err = t.clearSession(func(k string) error { | ||||
| 				tracker = k | ||||
| 				return nil | ||||
| 			}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(tracker).To(Equal(t.id)) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("errors when the clearFunc errors", func() { | ||||
| 			t, err := newTicket(&options.Cookie{Name: "dummy"}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			err = t.clearSession(func(k string) error { | ||||
| 				return errors.New("clear error") | ||||
| 			}) | ||||
| 			Expect(err).To(MatchError(errors.New("clear error"))) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
| 
 | ||||
| // TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession
 | ||||
| // when a V5 encoded session is in Redis
 | ||||
| //
 | ||||
| // TODO (@NickMeves): Remove when this is deprecated (likely V7)
 | ||||
| func Test_legacyV5LoadSession(t *testing.T) { | ||||
| 	testCases, _, _ := sessions.CreateLegacyV5TestCases(t) | ||||
| 
 | ||||
| 	for testName, tc := range testCases { | ||||
| 		t.Run(testName, func(t *testing.T) { | ||||
| 			g := NewWithT(t) | ||||
| 
 | ||||
| 			secret := make([]byte, aes.BlockSize) | ||||
| 			_, err := io.ReadFull(rand.Reader, secret) | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 			tckt := &ticket{ | ||||
| 				secret: secret, | ||||
| 				options: &options.Cookie{ | ||||
| 					Secret: base64.RawURLEncoding.EncodeToString([]byte(sessions.LegacyV5TestSecret)), | ||||
| 				}, | ||||
| 			} | ||||
| 
 | ||||
| 			encrypted, err := legacyStoreValue(tc.Input, tckt.secret) | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			ss, err := tckt.legacyV5LoadSession(encrypted) | ||||
| 			if tc.Error { | ||||
| 				g.Expect(err).To(HaveOccurred()) | ||||
| 				g.Expect(ss).To(BeNil()) | ||||
| 				return | ||||
| 			} | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			// Compare sessions without *time.Time fields
 | ||||
| 			exp := *tc.Output | ||||
| 			exp.CreatedAt = nil | ||||
| 			exp.ExpiresOn = nil | ||||
| 			act := *ss | ||||
| 			act.CreatedAt = nil | ||||
| 			act.ExpiresOn = nil | ||||
| 			g.Expect(exp).To(Equal(act)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // legacyStoreValue implements the legacy V5 Redis persistence AES-CFB value encryption
 | ||||
| //
 | ||||
| // TODO (@NickMeves): Remove when this is deprecated (likely V7)
 | ||||
| func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) { | ||||
| 	ciphertext := make([]byte, len(value)) | ||||
| 	block, err := aes.NewCipher(ticketSecret) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error initiating cipher block: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Use secret as the Initialization Vector too, because each entry has it's own key
 | ||||
| 	stream := cipher.NewCFBEncrypter(block, ticketSecret) | ||||
| 	stream.XORKeyStream(ciphertext, []byte(value)) | ||||
| 
 | ||||
| 	return ciphertext, nil | ||||
| } | ||||
|  | @ -2,69 +2,87 @@ package redis | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/x509" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/hex" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/go-redis/redis/v7" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence" | ||||
| ) | ||||
| 
 | ||||
| // TicketData is a structure representing the ticket used in server session storage
 | ||||
| type TicketData struct { | ||||
| 	TicketID string | ||||
| 	Secret   []byte | ||||
| } | ||||
| 
 | ||||
| // SessionStore is an implementation of the sessions.SessionStore
 | ||||
| // SessionStore is an implementation of the persistence.Store
 | ||||
| // interface that stores sessions in redis
 | ||||
| type SessionStore struct { | ||||
| 	CookieCipher encryption.Cipher | ||||
| 	Cookie       *options.Cookie | ||||
| 	Client Client | ||||
| } | ||||
| 
 | ||||
| // NewRedisSessionStore initialises a new instance of the SessionStore from
 | ||||
| // the configuration given
 | ||||
| // NewRedisSessionStore initialises a new instance of the SessionStore and wraps
 | ||||
| // it in a persistence.Manager
 | ||||
| func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { | ||||
| 	cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret)) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error initialising cipher: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	client, err := newRedisCmdable(opts.Redis) | ||||
| 	client, err := newRedisClient(opts.Redis) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error constructing redis client: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	rs := &SessionStore{ | ||||
| 		Client: client, | ||||
| 		CookieCipher: cfbCipher, | ||||
| 		Cookie:       cookieOpts, | ||||
| 	} | ||||
| 	return rs, nil | ||||
| 
 | ||||
| 	return persistence.NewManager(rs, cookieOpts), nil | ||||
| } | ||||
| 
 | ||||
| func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) { | ||||
| // Save takes a sessions.SessionState and stores the information from it
 | ||||
| // to redies, and adds a new persistence cookie on the HTTP response writer
 | ||||
| func (store *SessionStore) Save(ctx context.Context, key string, value []byte, exp time.Duration) error { | ||||
| 	err := store.Client.Set(ctx, key, value, exp) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error saving redis session: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Load reads sessions.SessionState information from a persistence
 | ||||
| // cookie within the HTTP request object
 | ||||
| func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error) { | ||||
| 	value, err := store.Client.Get(ctx, key) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error loading redis session: %v", err) | ||||
| 	} | ||||
| 	return value, nil | ||||
| } | ||||
| 
 | ||||
| // Clear clears any saved session information for a given persistence cookie
 | ||||
| // from redis, and then clears the session
 | ||||
| func (store *SessionStore) Clear(ctx context.Context, key string) error { | ||||
| 	err := store.Client.Del(ctx, key) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error clearing the session from redis: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // newRedisClient makes a redis.Client (either standalone, sentinel aware, or
 | ||||
| // redis cluster)
 | ||||
| func newRedisClient(opts options.RedisStoreOptions) (Client, error) { | ||||
| 	if opts.UseSentinel && opts.UseCluster { | ||||
| 		return nil, fmt.Errorf("options redis-use-sentinel and redis-use-cluster are mutually exclusive") | ||||
| 	} | ||||
| 
 | ||||
| 	if opts.UseSentinel { | ||||
| 		return buildSentinelClient(opts) | ||||
| 	} | ||||
| 	if opts.UseCluster { | ||||
| 		return buildClusterClient(opts) | ||||
| 	} | ||||
| 
 | ||||
| 	return buildStandaloneClient(opts) | ||||
| } | ||||
| 
 | ||||
| // buildSentinelClient makes a redis.Client that connects to Redis Sentinel
 | ||||
| // for Primary/Replica Redis node coordination
 | ||||
| func buildSentinelClient(opts options.RedisStoreOptions) (Client, error) { | ||||
| 	addrs, err := parseRedisURLs(opts.SentinelConnectionURLs) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not parse redis urls: %v", err) | ||||
|  | @ -76,7 +94,8 @@ func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) { | |||
| 	return newClient(client), nil | ||||
| } | ||||
| 
 | ||||
| 	if opts.UseCluster { | ||||
| // buildClusterClient makes a redis.Client that is Redis Cluster aware
 | ||||
| func buildClusterClient(opts options.RedisStoreOptions) (Client, error) { | ||||
| 	addrs, err := parseRedisURLs(opts.ClusterConnectionURLs) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not parse redis urls: %v", err) | ||||
|  | @ -87,6 +106,9 @@ func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) { | |||
| 	return newClusterClient(client), nil | ||||
| } | ||||
| 
 | ||||
| // buildStandaloneClient makes a redis.Client that connects to a simple
 | ||||
| // Redis node
 | ||||
| func buildStandaloneClient(opts options.RedisStoreOptions) (Client, error) { | ||||
| 	opt, err := redis.ParseURL(opts.ConnectionURL) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("unable to parse redis url: %s", err) | ||||
|  | @ -134,261 +156,3 @@ func parseRedisURLs(urls []string) ([]string, error) { | |||
| 	} | ||||
| 	return addrs, nil | ||||
| } | ||||
| 
 | ||||
| // Save takes a sessions.SessionState and stores the information from it
 | ||||
| // to redies, and adds a new ticket cookie on the HTTP response writer
 | ||||
| func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | ||||
| 	if s.CreatedAt == nil || s.CreatedAt.IsZero() { | ||||
| 		now := time.Now() | ||||
| 		s.CreatedAt = &now | ||||
| 	} | ||||
| 
 | ||||
| 	// Old sessions that we are refreshing would have a request cookie
 | ||||
| 	// New sessions don't, so we ignore the error. storeValue will check requestCookie
 | ||||
| 	requestCookie, _ := req.Cookie(store.Cookie.Name) | ||||
| 	ctx := req.Context() | ||||
| 	ticketString, err := store.saveSession(ctx, s, store.Cookie.Expire, requestCookie) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	ticketCookie := store.makeCookie( | ||||
| 		req, | ||||
| 		ticketString, | ||||
| 		store.Cookie.Expire, | ||||
| 		*s.CreatedAt, | ||||
| 	) | ||||
| 
 | ||||
| 	http.SetCookie(rw, ticketCookie) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Load reads sessions.SessionState information from a ticket
 | ||||
| // cookie within the HTTP request object
 | ||||
| func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { | ||||
| 	requestCookie, err := req.Cookie(store.Cookie.Name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error loading session: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire) | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("cookie signature not valid") | ||||
| 	} | ||||
| 	ctx := req.Context() | ||||
| 	session, err := store.loadSessionFromTicket(ctx, string(val)) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error loading session: %s", err) | ||||
| 	} | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| // Clear clears any saved session information for a given ticket cookie
 | ||||
| // from redis, and then clears the session
 | ||||
| func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| 	// We go ahead and clear the cookie first, always.
 | ||||
| 	clearCookie := store.makeCookie( | ||||
| 		req, | ||||
| 		"", | ||||
| 		time.Hour*-1, | ||||
| 		time.Now(), | ||||
| 	) | ||||
| 	http.SetCookie(rw, clearCookie) | ||||
| 
 | ||||
| 	// If there was an existing cookie we should clear the session in redis
 | ||||
| 	requestCookie, err := req.Cookie(store.Cookie.Name) | ||||
| 	if err != nil && err == http.ErrNoCookie { | ||||
| 		// No existing cookie so can't clear redis
 | ||||
| 		return nil | ||||
| 	} else if err != nil { | ||||
| 		return fmt.Errorf("error retrieving cookie: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("cookie signature not valid") | ||||
| 	} | ||||
| 
 | ||||
| 	// We only return an error if we had an issue with redis
 | ||||
| 	// If there's an issue decoding the ticket, ignore it
 | ||||
| 	ticket, _ := decodeTicket(store.Cookie.Name, string(val)) | ||||
| 	if ticket != nil { | ||||
| 		ctx := req.Context() | ||||
| 		err := store.Client.Del(ctx, ticket.asHandle(store.Cookie.Name)) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error clearing cookie from redis: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // saveSession encodes a session with a GCM cipher & saves the data into Redis
 | ||||
| func (store *SessionStore) saveSession(ctx context.Context, s *sessions.SessionState, expiration time.Duration, requestCookie *http.Cookie) (string, error) { | ||||
| 	ticket, err := store.getTicket(requestCookie) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error getting ticket: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	c, err := encryption.NewGCMCipher(ticket.Secret) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error initiating cipher block %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Use AES-GCM since it provides authenticated encryption
 | ||||
| 	// AES-CFB used in cookies has the cookie signing SHA to get around the lack of
 | ||||
| 	// authentication in AES-CFB
 | ||||
| 	ciphertext, err := s.EncodeSessionState(c, false) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	handle := ticket.asHandle(store.Cookie.Name) | ||||
| 	err = store.Client.Set(ctx, handle, ciphertext, expiration) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return ticket.encodeTicket(store.Cookie.Name), nil | ||||
| } | ||||
| 
 | ||||
| // loadSessionFromTicket loads the session based on the ticket value
 | ||||
| func (store *SessionStore) loadSessionFromTicket(ctx context.Context, value string) (*sessions.SessionState, error) { | ||||
| 	ticket, err := decodeTicket(store.Cookie.Name, value) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.Cookie.Name)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	c, err := encryption.NewGCMCipher(ticket.Secret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	session, err := sessions.DecodeSessionState(resultBytes, c, false) | ||||
| 	if err != nil { | ||||
| 		// The GCM cipher will error due to a legacy JSON payload not passing
 | ||||
| 		// the authentication check part of AES GCM encryption.
 | ||||
| 		// In that case, we can attempt to fallback to try a legacy load
 | ||||
| 		legacyCipher := encryption.NewBase64Cipher(store.CookieCipher) | ||||
| 		return legacyV5DecodeSession(resultBytes, ticket, legacyCipher) | ||||
| 	} | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| // legacyV5DecodeSession loads the session based on the ticket value
 | ||||
| // This fallback uses V5 style encryption of Base64 + AES CFB
 | ||||
| func legacyV5DecodeSession(resultBytes []byte, ticket *TicketData, c encryption.Cipher) (*sessions.SessionState, error) { | ||||
| 	block, err := aes.NewCipher(ticket.Secret) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	// Use secret as the IV too, because each entry has it's own key
 | ||||
| 	stream := cipher.NewCFBDecrypter(block, ticket.Secret) | ||||
| 	stream.XORKeyStream(resultBytes, resultBytes) | ||||
| 
 | ||||
| 	session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), c) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| // makeCookie makes a cookie, signing the value if present
 | ||||
| func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { | ||||
| 	if value != "" { | ||||
| 		value = encryption.SignedValue(store.Cookie.Secret, store.Cookie.Name, []byte(value), now) | ||||
| 	} | ||||
| 	return cookies.MakeCookieFromOptions( | ||||
| 		req, | ||||
| 		store.Cookie.Name, | ||||
| 		value, | ||||
| 		store.Cookie, | ||||
| 		expires, | ||||
| 		now, | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| // getTicket retrieves an existing ticket from the cookie if present,
 | ||||
| // or creates a new ticket
 | ||||
| func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, error) { | ||||
| 	if requestCookie == nil { | ||||
| 		return newTicket() | ||||
| 	} | ||||
| 
 | ||||
| 	// An existing cookie exists, try to retrieve the ticket
 | ||||
| 	val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire) | ||||
| 	if !ok { | ||||
| 		// Cookie is invalid, create a new ticket
 | ||||
| 		return newTicket() | ||||
| 	} | ||||
| 
 | ||||
| 	// Valid cookie, decode the ticket
 | ||||
| 	ticket, err := decodeTicket(store.Cookie.Name, string(val)) | ||||
| 	if err != nil { | ||||
| 		// If we can't decode the ticket we have to create a new one
 | ||||
| 		return newTicket() | ||||
| 	} | ||||
| 	return ticket, nil | ||||
| } | ||||
| 
 | ||||
| func newTicket() (*TicketData, error) { | ||||
| 	rawID := make([]byte, 16) | ||||
| 	if _, err := io.ReadFull(rand.Reader, rawID); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create new ticket ID %s", err) | ||||
| 	} | ||||
| 	// ticketID is hex encoded
 | ||||
| 	ticketID := hex.EncodeToString(rawID) | ||||
| 
 | ||||
| 	secret := make([]byte, aes.BlockSize) | ||||
| 	if _, err := io.ReadFull(rand.Reader, secret); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create initialization vector %s", err) | ||||
| 	} | ||||
| 	ticket := &TicketData{ | ||||
| 		TicketID: ticketID, | ||||
| 		Secret:   secret, | ||||
| 	} | ||||
| 	return ticket, nil | ||||
| } | ||||
| 
 | ||||
| func (ticket *TicketData) asHandle(prefix string) string { | ||||
| 	return fmt.Sprintf("%s-%s", prefix, ticket.TicketID) | ||||
| } | ||||
| 
 | ||||
| func decodeTicket(cookieName string, ticketString string) (*TicketData, error) { | ||||
| 	prefix := cookieName + "-" | ||||
| 	if !strings.HasPrefix(ticketString, prefix) { | ||||
| 		return nil, fmt.Errorf("failed to decode ticket handle") | ||||
| 	} | ||||
| 	trimmedTicket := strings.TrimPrefix(ticketString, prefix) | ||||
| 
 | ||||
| 	ticketParts := strings.Split(trimmedTicket, ".") | ||||
| 	if len(ticketParts) != 2 { | ||||
| 		return nil, fmt.Errorf("failed to decode ticket") | ||||
| 	} | ||||
| 	ticketID, secretBase64 := ticketParts[0], ticketParts[1] | ||||
| 
 | ||||
| 	// ticketID must be a hexadecimal string
 | ||||
| 	_, err := hex.DecodeString(ticketID) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("server ticket failed sanity checks") | ||||
| 	} | ||||
| 
 | ||||
| 	secret, err := base64.RawURLEncoding.DecodeString(secretBase64) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to decode initialization vector %s", err) | ||||
| 	} | ||||
| 	ticketData := &TicketData{ | ||||
| 		TicketID: ticketID, | ||||
| 		Secret:   secret, | ||||
| 	} | ||||
| 	return ticketData, nil | ||||
| } | ||||
| 
 | ||||
| func (ticket *TicketData) encodeTicket(prefix string) string { | ||||
| 	handle := ticket.asHandle(prefix) | ||||
| 	ticketString := handle + "." + base64.RawURLEncoding.EncodeToString(ticket.Secret) | ||||
| 	return ticketString | ||||
| } | ||||
|  |  | |||
|  | @ -1,11 +1,6 @@ | |||
| package redis | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/rand" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"testing" | ||||
|  | @ -17,70 +12,12 @@ import ( | |||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| // TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession
 | ||||
| // when a V5 encoded session is in Redis
 | ||||
| //
 | ||||
| // TODO: Remove when this is deprecated (likely V7)
 | ||||
| func Test_legacyV5DecodeSession(t *testing.T) { | ||||
| 	testCases, _, legacyCipher := sessionsapi.CreateLegacyV5TestCases(t) | ||||
| 
 | ||||
| 	for testName, tc := range testCases { | ||||
| 		t.Run(testName, func(t *testing.T) { | ||||
| 			g := NewWithT(t) | ||||
| 
 | ||||
| 			secret := make([]byte, aes.BlockSize) | ||||
| 			_, err := io.ReadFull(rand.Reader, secret) | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 			ticket := &TicketData{ | ||||
| 				TicketID: "", | ||||
| 				Secret:   secret, | ||||
| 			} | ||||
| 
 | ||||
| 			encrypted, err := legacyStoreValue(tc.Input, ticket.Secret) | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			ss, err := legacyV5DecodeSession(encrypted, ticket, legacyCipher) | ||||
| 			if tc.Error { | ||||
| 				g.Expect(err).To(HaveOccurred()) | ||||
| 				g.Expect(ss).To(BeNil()) | ||||
| 				return | ||||
| 			} | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			// Compare sessions without *time.Time fields
 | ||||
| 			exp := *tc.Output | ||||
| 			exp.CreatedAt = nil | ||||
| 			exp.ExpiresOn = nil | ||||
| 			act := *ss | ||||
| 			act.CreatedAt = nil | ||||
| 			act.ExpiresOn = nil | ||||
| 			g.Expect(exp).To(Equal(act)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // legacyStoreValue implements the legacy V5 Redis store AES-CFB value encryption
 | ||||
| //
 | ||||
| // TODO: Remove when this is deprecated (likely V7)
 | ||||
| func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) { | ||||
| 	ciphertext := make([]byte, len(value)) | ||||
| 	block, err := aes.NewCipher(ticketSecret) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error initiating cipher block: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Use secret as the Initialization Vector too, because each entry has it's own key
 | ||||
| 	stream := cipher.NewCFBEncrypter(block, ticketSecret) | ||||
| 	stream.XORKeyStream(ciphertext, []byte(value)) | ||||
| 
 | ||||
| 	return ciphertext, nil | ||||
| } | ||||
| 
 | ||||
| func TestSessionStore(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 
 | ||||
|  | @ -114,9 +51,9 @@ var _ = Describe("Redis SessionStore Tests", func() { | |||
| 
 | ||||
| 	JustAfterEach(func() { | ||||
| 		// Release any connections immediately after the test ends
 | ||||
| 		if redisStore, ok := ss.(*SessionStore); ok { | ||||
| 			if redisStore.Client != nil { | ||||
| 				Expect(redisStore.Client.(closer).Close()).To(Succeed()) | ||||
| 		if redisManager, ok := ss.(*persistence.Manager); ok { | ||||
| 			if redisManager.Store.(*SessionStore).Client != nil { | ||||
| 				Expect(redisManager.Store.(*SessionStore).Client.(closer).Close()).To(Succeed()) | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ import ( | |||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" | ||||
| 	sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/redis" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -66,10 +67,11 @@ var _ = Describe("NewSessionStore", func() { | |||
| 			opts.Redis.ConnectionURL = "redis://" | ||||
| 		}) | ||||
| 
 | ||||
| 		It("creates a redis.SessionStore", func() { | ||||
| 		It("creates a persistence.Manager that wraps a redis.SessionStore", func() { | ||||
| 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 			Expect(ss).To(BeAssignableToTypeOf(&redis.SessionStore{})) | ||||
| 			Expect(ss).To(BeAssignableToTypeOf(&persistence.Manager{})) | ||||
| 			Expect(ss.(*persistence.Manager).Store).To(BeAssignableToTypeOf(&redis.SessionStore{})) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,58 @@ | |||
| package tests | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // entry is a MockStore cache entry with an expiration
 | ||||
| type entry struct { | ||||
| 	data       []byte | ||||
| 	expiration time.Duration | ||||
| } | ||||
| 
 | ||||
| // MockStore is a generic in-memory implementation of persistence.Store
 | ||||
| // for mocking in tests
 | ||||
| type MockStore struct { | ||||
| 	cache   map[string]entry | ||||
| 	elapsed time.Duration | ||||
| } | ||||
| 
 | ||||
| // NewMockStore creates a MockStore
 | ||||
| func NewMockStore() *MockStore { | ||||
| 	return &MockStore{ | ||||
| 		cache:   map[string]entry{}, | ||||
| 		elapsed: 0 * time.Second, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Save sets a key to the data to the memory cache
 | ||||
| func (s *MockStore) Save(_ context.Context, key string, value []byte, exp time.Duration) error { | ||||
| 	s.cache[key] = entry{ | ||||
| 		data:       value, | ||||
| 		expiration: exp, | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Load gets data from the memory cache via a key
 | ||||
| func (s *MockStore) Load(_ context.Context, key string) ([]byte, error) { | ||||
| 	entry, ok := s.cache[key] | ||||
| 	if !ok || entry.expiration <= s.elapsed { | ||||
| 		delete(s.cache, key) | ||||
| 		return nil, fmt.Errorf("key not found: %s", key) | ||||
| 	} | ||||
| 	return entry.data, nil | ||||
| } | ||||
| 
 | ||||
| // Clear deletes an entry from the memory cache
 | ||||
| func (s *MockStore) Clear(_ context.Context, key string) error { | ||||
| 	delete(s.cache, key) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // FastForward simulates the flow of time to test expirations
 | ||||
| func (s *MockStore) FastForward(duration time.Duration) { | ||||
| 	s.elapsed += duration | ||||
| } | ||||
|  | @ -133,18 +133,6 @@ func RunSessionStoreTests(newSS NewSessionStoreFunc, persistentFastForward Persi | |||
| 				PersistentSessionStoreInterfaceTests(&input) | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with an invalid cookie secret", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				input.cookieOpts.Secret = "invalid" | ||||
| 			}) | ||||
| 
 | ||||
| 			It("returns an error when initialising the session store", func() { | ||||
| 				ss, err := newSS(opts, input.cookieOpts) | ||||
| 				Expect(err).To(MatchError("error initialising cipher: crypto/aes: invalid key size 7")) | ||||
| 				Expect(ss).To(BeNil()) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue