From 9643a0b10c9bac2899a241d94e66ac2a64cf6d57 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sun, 19 Jul 2020 13:25:13 -0700 Subject: [PATCH] Centralize Ticket management of persistent stores (#682) * Centralize Ticket management of persistent stores persistence package with Manager & Ticket will handle all the details about keys, secrets, ticket into cookies, etc. Persistent stores just need to pass Save, Load & Clear function handles to the persistent manager now. * Shift to persistence.Manager wrapping a persistence.Store * Break up the Redis client builder logic * Move error messages to Store from Manager * Convert ticket to private for Manager use only * Add persistence Manager & ticket tests * Make a custom MockStore that handles time FastForwards --- CHANGELOG.md | 1 + pkg/apis/sessions/legacy_v5_tester.go | 6 +- pkg/sessions/persistence/interfaces.go | 15 + pkg/sessions/persistence/manager.go | 91 +++++ pkg/sessions/persistence/manager_test.go | 34 ++ pkg/sessions/persistence/ticket.go | 221 +++++++++++++ pkg/sessions/persistence/ticket_test.go | 223 +++++++++++++ pkg/sessions/redis/redis_store.go | 386 +++++----------------- pkg/sessions/redis/redis_store_test.go | 71 +--- pkg/sessions/session_store_test.go | 6 +- pkg/sessions/tests/mock_store.go | 58 ++++ pkg/sessions/tests/session_store_tests.go | 12 - 12 files changed, 729 insertions(+), 395 deletions(-) create mode 100644 pkg/sessions/persistence/interfaces.go create mode 100644 pkg/sessions/persistence/manager.go create mode 100644 pkg/sessions/persistence/manager_test.go create mode 100644 pkg/sessions/persistence/ticket.go create mode 100644 pkg/sessions/persistence/ticket_test.go create mode 100644 pkg/sessions/tests/mock_store.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 341b20d3..74215868 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ## Changes since v6.0.0 +- [#682](https://github.com/oauth2-proxy/oauth2-proxy/pull/682) Refactor persistent session store session ticket management (@NickMeves) - [#688](https://github.com/oauth2-proxy/oauth2-proxy/pull/688) Refactor session loading to make use of middleware pattern (@JoelSpeed) - [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed) - [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed) diff --git a/pkg/apis/sessions/legacy_v5_tester.go b/pkg/apis/sessions/legacy_v5_tester.go index 44cf4e73..25bc0b0c 100644 --- a/pkg/apis/sessions/legacy_v5_tester.go +++ b/pkg/apis/sessions/legacy_v5_tester.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/assert" ) +const LegacyV5TestSecret = "0123456789abcdefghijklmnopqrstuv" + // LegacyV5TestCase provides V5 JSON based test cases for legacy fallback code type LegacyV5TestCase struct { Input string @@ -22,8 +24,6 @@ type LegacyV5TestCase struct { // // TODO: Remove when this is deprecated (likely V7) func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encryption.Cipher, encryption.Cipher) { - const secret = "0123456789abcdefghijklmnopqrstuv" - created := time.Now() createdJSON, err := created.MarshalJSON() assert.NoError(t, err) @@ -33,7 +33,7 @@ func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encrypt assert.NoError(t, err) eString := string(eJSON) - cfbCipher, err := encryption.NewCFBCipher([]byte(secret)) + cfbCipher, err := encryption.NewCFBCipher([]byte(LegacyV5TestSecret)) assert.NoError(t, err) legacyCipher := encryption.NewBase64Cipher(cfbCipher) diff --git a/pkg/sessions/persistence/interfaces.go b/pkg/sessions/persistence/interfaces.go new file mode 100644 index 00000000..cd983b4c --- /dev/null +++ b/pkg/sessions/persistence/interfaces.go @@ -0,0 +1,15 @@ +package persistence + +import ( + "context" + "time" +) + +// Store is used for persistent session stores (IE not Cookie) +// Implementing this interface allows it to easily use the persistence.Manager +// for session ticket + encryption details. +type Store interface { + Save(context.Context, string, []byte, time.Duration) error + Load(context.Context, string) ([]byte, error) + Clear(context.Context, string) error +} diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go new file mode 100644 index 00000000..16a54b4e --- /dev/null +++ b/pkg/sessions/persistence/manager.go @@ -0,0 +1,91 @@ +package persistence + +import ( + "fmt" + "net/http" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" +) + +// Manager wraps a Store and handles the implementation details of the +// sessions.SessionStore with its use of session tickets +type Manager struct { + Store Store + Options *options.Cookie +} + +// NewManager creates a Manager that can wrap a Store and manage the +// sessions.SessionStore implementation details +func NewManager(store Store, cookieOpts *options.Cookie) *Manager { + return &Manager{ + Store: store, + Options: cookieOpts, + } +} + +// Save saves a session in a persistent Store. Save will generate (or reuse an +// existing) ticket which manages unique per session encryption & retrieval +// from the persistent data store. +func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { + if s.CreatedAt == nil || s.CreatedAt.IsZero() { + now := time.Now() + s.CreatedAt = &now + } + + tckt, err := decodeTicketFromRequest(req, m.Options) + if err != nil { + tckt, err = newTicket(m.Options) + if err != nil { + return fmt.Errorf("error creating a session ticket: %v", err) + } + } + + err = tckt.saveSession(s, func(key string, val []byte, exp time.Duration) error { + return m.Store.Save(req.Context(), key, val, exp) + }) + if err != nil { + return err + } + tckt.setCookie(rw, req, s) + + return nil +} + +// Load reads sessions.SessionState information from a session store. It will +// use the session ticket from the http.Request's cookie. +func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) { + tckt, err := decodeTicketFromRequest(req, m.Options) + if err != nil { + return nil, err + } + + return tckt.loadSession(func(key string) ([]byte, error) { + return m.Store.Load(req.Context(), key) + }) +} + +// Clear clears any saved session information for a given ticket cookie. +// Then it clears all session data for that ticket in the Store. +func (m *Manager) Clear(rw http.ResponseWriter, req *http.Request) error { + tckt, err := decodeTicketFromRequest(req, m.Options) + if err != nil { + // Always clear the cookie, even when we can't load a cookie from + // the request + tckt = &ticket{ + options: m.Options, + } + tckt.clearCookie(rw, req) + // Don't raise an error if we didn't have a Cookie + if err == http.ErrNoCookie { + return nil + } + return fmt.Errorf("error decoding ticket to clear session: %v", err) + } + + tckt.clearCookie(rw, req) + return tckt.clearSession(func(key string) error { + return m.Store.Clear(req.Context(), key) + }) +} diff --git a/pkg/sessions/persistence/manager_test.go b/pkg/sessions/persistence/manager_test.go new file mode 100644 index 00000000..77a3b470 --- /dev/null +++ b/pkg/sessions/persistence/manager_test.go @@ -0,0 +1,34 @@ +package persistence + +import ( + "testing" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestManager(t *testing.T) { + logger.SetOutput(GinkgoWriter) + RegisterFailHandler(Fail) + RunSpecs(t, "Persistence Manager SessionStore") +} + +var _ = Describe("Persistence Manager SessionStore Tests", func() { + var ms *tests.MockStore + BeforeEach(func() { + ms = tests.NewMockStore() + }) + tests.RunSessionStoreTests( + func(_ *options.SessionOptions, cookieOpts *options.Cookie) (sessionsapi.SessionStore, error) { + return NewManager(ms, cookieOpts), nil + }, + func(d time.Duration) error { + ms.FastForward(d) + return nil + }) +}) diff --git a/pkg/sessions/persistence/ticket.go b/pkg/sessions/persistence/ticket.go new file mode 100644 index 00000000..da1097d9 --- /dev/null +++ b/pkg/sessions/persistence/ticket.go @@ -0,0 +1,221 @@ +package persistence + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" + "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" +) + +// saveFunc performs a persistent store's save functionality using +// a key string, value []byte & (optional) expiration time.Duration +type saveFunc func(string, []byte, time.Duration) error + +// loadFunc performs a load from a persistent store using a +// string key and returning the stored value as []byte +type loadFunc func(string) ([]byte, error) + +// clearFunc performs a persistent store's clear functionality using +// a string key for the target of the deletion. +type clearFunc func(string) error + +// ticket is a structure representing the ticket used in server based +// session storage. It provides a unique per session decryption secret giving +// more security than the shared CookieSecret. +type ticket struct { + id string + secret []byte + options *options.Cookie +} + +// newTicket creates a new ticket. The ID & secret will be randomly created +// with 16 byte sizes. The ID will be prefixed & hex encoded. +func newTicket(cookieOpts *options.Cookie) (*ticket, error) { + rawID := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, rawID); err != nil { + return nil, fmt.Errorf("failed to create new ticket ID: %v", err) + } + // ticketID is hex encoded + ticketID := fmt.Sprintf("%s-%s", cookieOpts.Name, hex.EncodeToString(rawID)) + + secret := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, secret); err != nil { + return nil, fmt.Errorf("failed to create encryption secret: %v", err) + } + + return &ticket{ + id: ticketID, + secret: secret, + options: cookieOpts, + }, nil +} + +// encodeTicket encodes the Ticket to a string for usage in cookies +func (t *ticket) encodeTicket() string { + return fmt.Sprintf("%s.%s", t.id, base64.RawURLEncoding.EncodeToString(t.secret)) +} + +// decodeTicket decodes an encoded ticket string +func decodeTicket(encTicket string, cookieOpts *options.Cookie) (*ticket, error) { + ticketParts := strings.Split(encTicket, ".") + if len(ticketParts) != 2 { + return nil, errors.New("failed to decode ticket") + } + ticketID, secretBase64 := ticketParts[0], ticketParts[1] + + secret, err := base64.RawURLEncoding.DecodeString(secretBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode encryption secret: %v", err) + } + + return &ticket{ + id: ticketID, + secret: secret, + options: cookieOpts, + }, nil +} + +// decodeTicketFromRequest retrieves a potential ticket cookie from a request +// and decodes it to a ticket. +func decodeTicketFromRequest(req *http.Request, cookieOpts *options.Cookie) (*ticket, error) { + requestCookie, err := req.Cookie(cookieOpts.Name) + if err != nil { + // Don't wrap this error to allow `err == http.ErrNoCookie` checks + return nil, err + } + + // An existing cookie exists, try to retrieve the ticket + val, _, ok := encryption.Validate(requestCookie, cookieOpts.Secret, cookieOpts.Expire) + if !ok { + return nil, fmt.Errorf("session ticket cookie failed validation: %v", err) + } + + // Valid cookie, decode the ticket + return decodeTicket(string(val), cookieOpts) +} + +// saveSession encodes the SessionState with the ticket's secret and persists +// it to disk via the passed saveFunc. +func (t *ticket) saveSession(s *sessions.SessionState, saver saveFunc) error { + c, err := t.makeCipher() + if err != nil { + return err + } + ciphertext, err := s.EncodeSessionState(c, false) + if err != nil { + return fmt.Errorf("failed to encode the session state with the ticket: %v", err) + } + return saver(t.id, ciphertext, t.options.Expire) +} + +// loadSession loads a session from the disk store via the passed loadFunc +// using the ticket.id as the key. It then decodes the SessionState using +// ticket.secret to make the AES-GCM cipher. +// +// TODO (@NickMeves): Remove legacyV5LoadSession support in V7 +func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) { + ciphertext, err := loader(t.id) + if err != nil { + return nil, fmt.Errorf("failed to load the session state with the ticket: %v", err) + } + c, err := t.makeCipher() + if err != nil { + return nil, err + } + ss, err := sessions.DecodeSessionState(ciphertext, c, false) + if err != nil { + return t.legacyV5LoadSession(ciphertext) + } + return ss, nil +} + +// clearSession uses the passed clearFunc to delete a session stored with a +// key of ticket.id +func (t *ticket) clearSession(clearer clearFunc) error { + return clearer(t.id) +} + +// setCookie sets the encoded ticket as a cookie +func (t *ticket) setCookie(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) { + ticketCookie := t.makeCookie( + req, + t.encodeTicket(), + t.options.Expire, + *s.CreatedAt, + ) + + http.SetCookie(rw, ticketCookie) +} + +// clearCookie removes any cookies that would be where this ticket +// would set them +func (t *ticket) clearCookie(rw http.ResponseWriter, req *http.Request) { + clearCookie := t.makeCookie( + req, + "", + time.Hour*-1, + time.Now(), + ) + http.SetCookie(rw, clearCookie) +} + +// makeCookie makes a cookie, signing the value if present +func (t *ticket) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { + if value != "" { + value = encryption.SignedValue(t.options.Secret, t.options.Name, []byte(value), now) + } + return cookies.MakeCookieFromOptions( + req, + t.options.Name, + value, + t.options, + expires, + now, + ) +} + +// makeCipher makes a AES-GCM cipher out of the ticket's secret +func (t *ticket) makeCipher() (encryption.Cipher, error) { + c, err := encryption.NewGCMCipher(t.secret) + if err != nil { + return nil, fmt.Errorf("failed to make an AES-GCM cipher from the ticket secret: %v", err) + } + return c, nil +} + +// legacyV5LoadSession loads a Redis session created in V5 with historical logic +// +// TODO (@NickMeves): Remove in V7 +func (t *ticket) legacyV5LoadSession(resultBytes []byte) (*sessions.SessionState, error) { + block, err := aes.NewCipher(t.secret) + if err != nil { + return nil, fmt.Errorf("failed to create a legacy AES-CFB cipher from the ticket secret: %v", err) + } + + stream := cipher.NewCFBDecrypter(block, t.secret) + stream.XORKeyStream(resultBytes, resultBytes) + + cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(t.options.Secret)) + if err != nil { + return nil, err + } + legacyCipher := encryption.NewBase64Cipher(cfbCipher) + + session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), legacyCipher) + if err != nil { + return nil, err + } + return session, nil +} diff --git a/pkg/sessions/persistence/ticket_test.go b/pkg/sessions/persistence/ticket_test.go new file mode 100644 index 00000000..821117a3 --- /dev/null +++ b/pkg/sessions/persistence/ticket_test.go @@ -0,0 +1,223 @@ +package persistence + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" + "testing" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +func Test_ticket(t *testing.T) { + logger.SetOutput(GinkgoWriter) + RegisterFailHandler(Fail) + RunSpecs(t, "Session Ticket") +} + +var _ = Describe("Session Ticket Tests", func() { + Context("encodeTicket & decodeTicket", func() { + type ticketTableInput struct { + ticket *ticket + encodedTicket string + expectedError error + } + + DescribeTable("encodeTicket should decodeTicket back when valid", + func(in ticketTableInput) { + if in.ticket != nil { + enc := in.ticket.encodeTicket() + Expect(enc).To(Equal(in.encodedTicket)) + + dec, err := decodeTicket(enc, in.ticket.options) + Expect(err).ToNot(HaveOccurred()) + Expect(dec).To(Equal(in.ticket)) + } else { + _, err := decodeTicket(in.encodedTicket, nil) + Expect(err).To(MatchError(in.expectedError)) + } + }, + Entry("with a valid ticket", ticketTableInput{ + ticket: &ticket{ + id: "dummy-0123456789abcdef", + secret: []byte("0123456789abcdef"), + options: &options.Cookie{ + Name: "dummy", + }, + }, + encodedTicket: fmt.Sprintf("%s.%s", + "dummy-0123456789abcdef", + base64.RawURLEncoding.EncodeToString([]byte("0123456789abcdef"))), + expectedError: nil, + }), + Entry("with an invalid encoded ticket with 1 part", ticketTableInput{ + ticket: nil, + encodedTicket: "dummy-0123456789abcdef", + expectedError: errors.New("failed to decode ticket"), + }), + Entry("with an invalid base64 encoded secret", ticketTableInput{ + ticket: nil, + encodedTicket: "dummy-0123456789abcdef.@)#($*@)#(*$@)#(*$", + expectedError: fmt.Errorf("failed to decode encryption secret: illegal base64 data at input byte 0"), + }), + ) + }) + + Context("saveSession", func() { + It("uses the passed save function", func() { + t, err := newTicket(&options.Cookie{Name: "dummy"}) + Expect(err).ToNot(HaveOccurred()) + + c, err := t.makeCipher() + Expect(err).ToNot(HaveOccurred()) + + ss := &sessions.SessionState{User: "foobar"} + store := map[string][]byte{} + err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error { + store[k] = v + return nil + }) + Expect(err).ToNot(HaveOccurred()) + + stored, err := sessions.DecodeSessionState(store[t.id], c, false) + Expect(err).ToNot(HaveOccurred()) + Expect(stored).To(Equal(ss)) + }) + + It("errors when the saveFunc errors", func() { + t, err := newTicket(&options.Cookie{Name: "dummy"}) + Expect(err).ToNot(HaveOccurred()) + + err = t.saveSession( + &sessions.SessionState{User: "foobar"}, + func(k string, v []byte, e time.Duration) error { + return errors.New("save error") + }) + Expect(err).To(MatchError(errors.New("save error"))) + }) + }) + + Context("loadSession", func() { + It("uses the passed load function", func() { + t, err := newTicket(&options.Cookie{Name: "dummy"}) + Expect(err).ToNot(HaveOccurred()) + + c, err := t.makeCipher() + Expect(err).ToNot(HaveOccurred()) + + ss := &sessions.SessionState{User: "foobar"} + loadedSession, err := t.loadSession(func(k string) ([]byte, error) { + return ss.EncodeSessionState(c, false) + }) + Expect(err).ToNot(HaveOccurred()) + Expect(loadedSession).To(Equal(ss)) + }) + + It("errors when the loadFunc errors", func() { + t, err := newTicket(&options.Cookie{Name: "dummy"}) + Expect(err).ToNot(HaveOccurred()) + + data, err := t.loadSession(func(k string) ([]byte, error) { + return nil, errors.New("load error") + }) + Expect(data).To(BeNil()) + Expect(err).To(MatchError(errors.New("failed to load the session state with the ticket: load error"))) + }) + }) + + Context("clearSession", func() { + It("uses the passed clear function", func() { + t, err := newTicket(&options.Cookie{Name: "dummy"}) + Expect(err).ToNot(HaveOccurred()) + + var tracker string + err = t.clearSession(func(k string) error { + tracker = k + return nil + }) + Expect(err).ToNot(HaveOccurred()) + Expect(tracker).To(Equal(t.id)) + }) + + It("errors when the clearFunc errors", func() { + t, err := newTicket(&options.Cookie{Name: "dummy"}) + Expect(err).ToNot(HaveOccurred()) + + err = t.clearSession(func(k string) error { + return errors.New("clear error") + }) + Expect(err).To(MatchError(errors.New("clear error"))) + }) + }) +}) + +// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession +// when a V5 encoded session is in Redis +// +// TODO (@NickMeves): Remove when this is deprecated (likely V7) +func Test_legacyV5LoadSession(t *testing.T) { + testCases, _, _ := sessions.CreateLegacyV5TestCases(t) + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + + secret := make([]byte, aes.BlockSize) + _, err := io.ReadFull(rand.Reader, secret) + g.Expect(err).ToNot(HaveOccurred()) + tckt := &ticket{ + secret: secret, + options: &options.Cookie{ + Secret: base64.RawURLEncoding.EncodeToString([]byte(sessions.LegacyV5TestSecret)), + }, + } + + encrypted, err := legacyStoreValue(tc.Input, tckt.secret) + g.Expect(err).ToNot(HaveOccurred()) + + ss, err := tckt.legacyV5LoadSession(encrypted) + if tc.Error { + g.Expect(err).To(HaveOccurred()) + g.Expect(ss).To(BeNil()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + // Compare sessions without *time.Time fields + exp := *tc.Output + exp.CreatedAt = nil + exp.ExpiresOn = nil + act := *ss + act.CreatedAt = nil + act.ExpiresOn = nil + g.Expect(exp).To(Equal(act)) + }) + } +} + +// legacyStoreValue implements the legacy V5 Redis persistence AES-CFB value encryption +// +// TODO (@NickMeves): Remove when this is deprecated (likely V7) +func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) { + ciphertext := make([]byte, len(value)) + block, err := aes.NewCipher(ticketSecret) + if err != nil { + return nil, fmt.Errorf("error initiating cipher block: %v", err) + } + + // Use secret as the Initialization Vector too, because each entry has it's own key + stream := cipher.NewCFBEncrypter(block, ticketSecret) + stream.XORKeyStream(ciphertext, []byte(value)) + + return ciphertext, nil +} diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index 0e0d7cd9..38fe3caa 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -2,91 +2,113 @@ package redis import ( "context" - "crypto/aes" - "crypto/cipher" - "crypto/rand" "crypto/x509" - "encoding/base64" - "encoding/hex" "fmt" - "io" "io/ioutil" - "net/http" - "strings" "time" "github.com/go-redis/redis/v7" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" - "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence" ) -// TicketData is a structure representing the ticket used in server session storage -type TicketData struct { - TicketID string - Secret []byte -} - -// SessionStore is an implementation of the sessions.SessionStore +// SessionStore is an implementation of the persistence.Store // interface that stores sessions in redis type SessionStore struct { - CookieCipher encryption.Cipher - Cookie *options.Cookie - Client Client + Client Client } -// NewRedisSessionStore initialises a new instance of the SessionStore from -// the configuration given +// NewRedisSessionStore initialises a new instance of the SessionStore and wraps +// it in a persistence.Manager func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { - cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret)) - if err != nil { - return nil, fmt.Errorf("error initialising cipher: %v", err) - } - - client, err := newRedisCmdable(opts.Redis) + client, err := newRedisClient(opts.Redis) if err != nil { return nil, fmt.Errorf("error constructing redis client: %v", err) } rs := &SessionStore{ - Client: client, - CookieCipher: cfbCipher, - Cookie: cookieOpts, + Client: client, } - return rs, nil - + return persistence.NewManager(rs, cookieOpts), nil } -func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) { +// Save takes a sessions.SessionState and stores the information from it +// to redies, and adds a new persistence cookie on the HTTP response writer +func (store *SessionStore) Save(ctx context.Context, key string, value []byte, exp time.Duration) error { + err := store.Client.Set(ctx, key, value, exp) + if err != nil { + return fmt.Errorf("error saving redis session: %v", err) + } + return nil +} + +// Load reads sessions.SessionState information from a persistence +// cookie within the HTTP request object +func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error) { + value, err := store.Client.Get(ctx, key) + if err != nil { + return nil, fmt.Errorf("error loading redis session: %v", err) + } + return value, nil +} + +// Clear clears any saved session information for a given persistence cookie +// from redis, and then clears the session +func (store *SessionStore) Clear(ctx context.Context, key string) error { + err := store.Client.Del(ctx, key) + if err != nil { + return fmt.Errorf("error clearing the session from redis: %v", err) + } + return nil +} + +// newRedisClient makes a redis.Client (either standalone, sentinel aware, or +// redis cluster) +func newRedisClient(opts options.RedisStoreOptions) (Client, error) { if opts.UseSentinel && opts.UseCluster { return nil, fmt.Errorf("options redis-use-sentinel and redis-use-cluster are mutually exclusive") } - if opts.UseSentinel { - addrs, err := parseRedisURLs(opts.SentinelConnectionURLs) - if err != nil { - return nil, fmt.Errorf("could not parse redis urls: %v", err) - } - client := redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: opts.SentinelMasterName, - SentinelAddrs: addrs, - }) - return newClient(client), nil + return buildSentinelClient(opts) } - if opts.UseCluster { - addrs, err := parseRedisURLs(opts.ClusterConnectionURLs) - if err != nil { - return nil, fmt.Errorf("could not parse redis urls: %v", err) - } - client := redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: addrs, - }) - return newClusterClient(client), nil + return buildClusterClient(opts) } + return buildStandaloneClient(opts) +} + +// buildSentinelClient makes a redis.Client that connects to Redis Sentinel +// for Primary/Replica Redis node coordination +func buildSentinelClient(opts options.RedisStoreOptions) (Client, error) { + addrs, err := parseRedisURLs(opts.SentinelConnectionURLs) + if err != nil { + return nil, fmt.Errorf("could not parse redis urls: %v", err) + } + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: opts.SentinelMasterName, + SentinelAddrs: addrs, + }) + return newClient(client), nil +} + +// buildClusterClient makes a redis.Client that is Redis Cluster aware +func buildClusterClient(opts options.RedisStoreOptions) (Client, error) { + addrs, err := parseRedisURLs(opts.ClusterConnectionURLs) + if err != nil { + return nil, fmt.Errorf("could not parse redis urls: %v", err) + } + client := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: addrs, + }) + return newClusterClient(client), nil +} + +// buildStandaloneClient makes a redis.Client that connects to a simple +// Redis node +func buildStandaloneClient(opts options.RedisStoreOptions) (Client, error) { opt, err := redis.ParseURL(opts.ConnectionURL) if err != nil { return nil, fmt.Errorf("unable to parse redis url: %s", err) @@ -134,261 +156,3 @@ func parseRedisURLs(urls []string) ([]string, error) { } return addrs, nil } - -// Save takes a sessions.SessionState and stores the information from it -// to redies, and adds a new ticket cookie on the HTTP response writer -func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { - if s.CreatedAt == nil || s.CreatedAt.IsZero() { - now := time.Now() - s.CreatedAt = &now - } - - // Old sessions that we are refreshing would have a request cookie - // New sessions don't, so we ignore the error. storeValue will check requestCookie - requestCookie, _ := req.Cookie(store.Cookie.Name) - ctx := req.Context() - ticketString, err := store.saveSession(ctx, s, store.Cookie.Expire, requestCookie) - if err != nil { - return err - } - - ticketCookie := store.makeCookie( - req, - ticketString, - store.Cookie.Expire, - *s.CreatedAt, - ) - - http.SetCookie(rw, ticketCookie) - return nil -} - -// Load reads sessions.SessionState information from a ticket -// cookie within the HTTP request object -func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { - requestCookie, err := req.Cookie(store.Cookie.Name) - if err != nil { - return nil, fmt.Errorf("error loading session: %s", err) - } - - val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire) - if !ok { - return nil, fmt.Errorf("cookie signature not valid") - } - ctx := req.Context() - session, err := store.loadSessionFromTicket(ctx, string(val)) - if err != nil { - return nil, fmt.Errorf("error loading session: %s", err) - } - return session, nil -} - -// Clear clears any saved session information for a given ticket cookie -// from redis, and then clears the session -func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { - // We go ahead and clear the cookie first, always. - clearCookie := store.makeCookie( - req, - "", - time.Hour*-1, - time.Now(), - ) - http.SetCookie(rw, clearCookie) - - // If there was an existing cookie we should clear the session in redis - requestCookie, err := req.Cookie(store.Cookie.Name) - if err != nil && err == http.ErrNoCookie { - // No existing cookie so can't clear redis - return nil - } else if err != nil { - return fmt.Errorf("error retrieving cookie: %v", err) - } - - val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire) - if !ok { - return fmt.Errorf("cookie signature not valid") - } - - // We only return an error if we had an issue with redis - // If there's an issue decoding the ticket, ignore it - ticket, _ := decodeTicket(store.Cookie.Name, string(val)) - if ticket != nil { - ctx := req.Context() - err := store.Client.Del(ctx, ticket.asHandle(store.Cookie.Name)) - if err != nil { - return fmt.Errorf("error clearing cookie from redis: %s", err) - } - } - return nil -} - -// saveSession encodes a session with a GCM cipher & saves the data into Redis -func (store *SessionStore) saveSession(ctx context.Context, s *sessions.SessionState, expiration time.Duration, requestCookie *http.Cookie) (string, error) { - ticket, err := store.getTicket(requestCookie) - if err != nil { - return "", fmt.Errorf("error getting ticket: %v", err) - } - - c, err := encryption.NewGCMCipher(ticket.Secret) - if err != nil { - return "", fmt.Errorf("error initiating cipher block %s", err) - } - - // Use AES-GCM since it provides authenticated encryption - // AES-CFB used in cookies has the cookie signing SHA to get around the lack of - // authentication in AES-CFB - ciphertext, err := s.EncodeSessionState(c, false) - if err != nil { - return "", err - } - - handle := ticket.asHandle(store.Cookie.Name) - err = store.Client.Set(ctx, handle, ciphertext, expiration) - if err != nil { - return "", err - } - return ticket.encodeTicket(store.Cookie.Name), nil -} - -// loadSessionFromTicket loads the session based on the ticket value -func (store *SessionStore) loadSessionFromTicket(ctx context.Context, value string) (*sessions.SessionState, error) { - ticket, err := decodeTicket(store.Cookie.Name, value) - if err != nil { - return nil, err - } - - resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.Cookie.Name)) - if err != nil { - return nil, err - } - - c, err := encryption.NewGCMCipher(ticket.Secret) - if err != nil { - return nil, err - } - - session, err := sessions.DecodeSessionState(resultBytes, c, false) - if err != nil { - // The GCM cipher will error due to a legacy JSON payload not passing - // the authentication check part of AES GCM encryption. - // In that case, we can attempt to fallback to try a legacy load - legacyCipher := encryption.NewBase64Cipher(store.CookieCipher) - return legacyV5DecodeSession(resultBytes, ticket, legacyCipher) - } - return session, nil -} - -// legacyV5DecodeSession loads the session based on the ticket value -// This fallback uses V5 style encryption of Base64 + AES CFB -func legacyV5DecodeSession(resultBytes []byte, ticket *TicketData, c encryption.Cipher) (*sessions.SessionState, error) { - block, err := aes.NewCipher(ticket.Secret) - if err != nil { - return nil, err - } - // Use secret as the IV too, because each entry has it's own key - stream := cipher.NewCFBDecrypter(block, ticket.Secret) - stream.XORKeyStream(resultBytes, resultBytes) - - session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), c) - if err != nil { - return nil, err - } - return session, nil -} - -// makeCookie makes a cookie, signing the value if present -func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { - if value != "" { - value = encryption.SignedValue(store.Cookie.Secret, store.Cookie.Name, []byte(value), now) - } - return cookies.MakeCookieFromOptions( - req, - store.Cookie.Name, - value, - store.Cookie, - expires, - now, - ) -} - -// getTicket retrieves an existing ticket from the cookie if present, -// or creates a new ticket -func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, error) { - if requestCookie == nil { - return newTicket() - } - - // An existing cookie exists, try to retrieve the ticket - val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire) - if !ok { - // Cookie is invalid, create a new ticket - return newTicket() - } - - // Valid cookie, decode the ticket - ticket, err := decodeTicket(store.Cookie.Name, string(val)) - if err != nil { - // If we can't decode the ticket we have to create a new one - return newTicket() - } - return ticket, nil -} - -func newTicket() (*TicketData, error) { - rawID := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, rawID); err != nil { - return nil, fmt.Errorf("failed to create new ticket ID %s", err) - } - // ticketID is hex encoded - ticketID := hex.EncodeToString(rawID) - - secret := make([]byte, aes.BlockSize) - if _, err := io.ReadFull(rand.Reader, secret); err != nil { - return nil, fmt.Errorf("failed to create initialization vector %s", err) - } - ticket := &TicketData{ - TicketID: ticketID, - Secret: secret, - } - return ticket, nil -} - -func (ticket *TicketData) asHandle(prefix string) string { - return fmt.Sprintf("%s-%s", prefix, ticket.TicketID) -} - -func decodeTicket(cookieName string, ticketString string) (*TicketData, error) { - prefix := cookieName + "-" - if !strings.HasPrefix(ticketString, prefix) { - return nil, fmt.Errorf("failed to decode ticket handle") - } - trimmedTicket := strings.TrimPrefix(ticketString, prefix) - - ticketParts := strings.Split(trimmedTicket, ".") - if len(ticketParts) != 2 { - return nil, fmt.Errorf("failed to decode ticket") - } - ticketID, secretBase64 := ticketParts[0], ticketParts[1] - - // ticketID must be a hexadecimal string - _, err := hex.DecodeString(ticketID) - if err != nil { - return nil, fmt.Errorf("server ticket failed sanity checks") - } - - secret, err := base64.RawURLEncoding.DecodeString(secretBase64) - if err != nil { - return nil, fmt.Errorf("failed to decode initialization vector %s", err) - } - ticketData := &TicketData{ - TicketID: ticketID, - Secret: secret, - } - return ticketData, nil -} - -func (ticket *TicketData) encodeTicket(prefix string) string { - handle := ticket.asHandle(prefix) - ticketString := handle + "." + base64.RawURLEncoding.EncodeToString(ticket.Secret) - return ticketString -} diff --git a/pkg/sessions/redis/redis_store_test.go b/pkg/sessions/redis/redis_store_test.go index 12965705..e5adc875 100644 --- a/pkg/sessions/redis/redis_store_test.go +++ b/pkg/sessions/redis/redis_store_test.go @@ -1,11 +1,6 @@ package redis import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "fmt" - "io" "log" "os" "testing" @@ -17,70 +12,12 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession -// when a V5 encoded session is in Redis -// -// TODO: Remove when this is deprecated (likely V7) -func Test_legacyV5DecodeSession(t *testing.T) { - testCases, _, legacyCipher := sessionsapi.CreateLegacyV5TestCases(t) - - for testName, tc := range testCases { - t.Run(testName, func(t *testing.T) { - g := NewWithT(t) - - secret := make([]byte, aes.BlockSize) - _, err := io.ReadFull(rand.Reader, secret) - g.Expect(err).ToNot(HaveOccurred()) - ticket := &TicketData{ - TicketID: "", - Secret: secret, - } - - encrypted, err := legacyStoreValue(tc.Input, ticket.Secret) - g.Expect(err).ToNot(HaveOccurred()) - - ss, err := legacyV5DecodeSession(encrypted, ticket, legacyCipher) - if tc.Error { - g.Expect(err).To(HaveOccurred()) - g.Expect(ss).To(BeNil()) - return - } - g.Expect(err).ToNot(HaveOccurred()) - - // Compare sessions without *time.Time fields - exp := *tc.Output - exp.CreatedAt = nil - exp.ExpiresOn = nil - act := *ss - act.CreatedAt = nil - act.ExpiresOn = nil - g.Expect(exp).To(Equal(act)) - }) - } -} - -// legacyStoreValue implements the legacy V5 Redis store AES-CFB value encryption -// -// TODO: Remove when this is deprecated (likely V7) -func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) { - ciphertext := make([]byte, len(value)) - block, err := aes.NewCipher(ticketSecret) - if err != nil { - return nil, fmt.Errorf("error initiating cipher block: %v", err) - } - - // Use secret as the Initialization Vector too, because each entry has it's own key - stream := cipher.NewCFBEncrypter(block, ticketSecret) - stream.XORKeyStream(ciphertext, []byte(value)) - - return ciphertext, nil -} - func TestSessionStore(t *testing.T) { logger.SetOutput(GinkgoWriter) @@ -114,9 +51,9 @@ var _ = Describe("Redis SessionStore Tests", func() { JustAfterEach(func() { // Release any connections immediately after the test ends - if redisStore, ok := ss.(*SessionStore); ok { - if redisStore.Client != nil { - Expect(redisStore.Client.(closer).Close()).To(Succeed()) + if redisManager, ok := ss.(*persistence.Manager); ok { + if redisManager.Store.(*SessionStore).Client != nil { + Expect(redisManager.Store.(*SessionStore).Client.(closer).Close()).To(Succeed()) } } }) diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 37c35888..2f8d5992 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -10,6 +10,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" + "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/redis" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -66,10 +67,11 @@ var _ = Describe("NewSessionStore", func() { opts.Redis.ConnectionURL = "redis://" }) - It("creates a redis.SessionStore", func() { + It("creates a persistence.Manager that wraps a redis.SessionStore", func() { ss, err := sessions.NewSessionStore(opts, cookieOpts) Expect(err).NotTo(HaveOccurred()) - Expect(ss).To(BeAssignableToTypeOf(&redis.SessionStore{})) + Expect(ss).To(BeAssignableToTypeOf(&persistence.Manager{})) + Expect(ss.(*persistence.Manager).Store).To(BeAssignableToTypeOf(&redis.SessionStore{})) }) }) diff --git a/pkg/sessions/tests/mock_store.go b/pkg/sessions/tests/mock_store.go new file mode 100644 index 00000000..1ddca76e --- /dev/null +++ b/pkg/sessions/tests/mock_store.go @@ -0,0 +1,58 @@ +package tests + +import ( + "context" + "fmt" + "time" +) + +// entry is a MockStore cache entry with an expiration +type entry struct { + data []byte + expiration time.Duration +} + +// MockStore is a generic in-memory implementation of persistence.Store +// for mocking in tests +type MockStore struct { + cache map[string]entry + elapsed time.Duration +} + +// NewMockStore creates a MockStore +func NewMockStore() *MockStore { + return &MockStore{ + cache: map[string]entry{}, + elapsed: 0 * time.Second, + } +} + +// Save sets a key to the data to the memory cache +func (s *MockStore) Save(_ context.Context, key string, value []byte, exp time.Duration) error { + s.cache[key] = entry{ + data: value, + expiration: exp, + } + return nil +} + +// Load gets data from the memory cache via a key +func (s *MockStore) Load(_ context.Context, key string) ([]byte, error) { + entry, ok := s.cache[key] + if !ok || entry.expiration <= s.elapsed { + delete(s.cache, key) + return nil, fmt.Errorf("key not found: %s", key) + } + return entry.data, nil +} + +// Clear deletes an entry from the memory cache +func (s *MockStore) Clear(_ context.Context, key string) error { + delete(s.cache, key) + return nil +} + +// FastForward simulates the flow of time to test expirations +func (s *MockStore) FastForward(duration time.Duration) { + s.elapsed += duration +} diff --git a/pkg/sessions/tests/session_store_tests.go b/pkg/sessions/tests/session_store_tests.go index 8778efe8..cf3f8173 100644 --- a/pkg/sessions/tests/session_store_tests.go +++ b/pkg/sessions/tests/session_store_tests.go @@ -133,18 +133,6 @@ func RunSessionStoreTests(newSS NewSessionStoreFunc, persistentFastForward Persi PersistentSessionStoreInterfaceTests(&input) } }) - - Context("with an invalid cookie secret", func() { - BeforeEach(func() { - input.cookieOpts.Secret = "invalid" - }) - - It("returns an error when initialising the session store", func() { - ss, err := newSS(opts, input.cookieOpts) - Expect(err).To(MatchError("error initialising cipher: crypto/aes: invalid key size 7")) - Expect(ss).To(BeNil()) - }) - }) }) }