From 7e80e5596b64ed9d25c9e2e119c7025da104f941 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 6 Mar 2021 15:33:13 -0800 Subject: [PATCH 1/6] RefreshSessions immediately when called --- oauthproxy.go | 8 ++--- pkg/middleware/stored_session.go | 49 ++++++++++++++------------- pkg/middleware/stored_session_test.go | 25 ++++++-------- providers/azure_test.go | 7 ++-- providers/facebook.go | 2 +- providers/gitlab.go | 21 +++++------- providers/google.go | 7 ++-- providers/linkedin.go | 2 +- providers/oidc.go | 7 ++-- providers/oidc_test.go | 4 +-- providers/provider_default.go | 13 ++++--- providers/provider_default_test.go | 2 +- providers/providers.go | 6 ++-- 13 files changed, 74 insertions(+), 79 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index d6479609..b0c94eb0 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt } chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ - SessionStore: sessionStore, - RefreshPeriod: opts.Cookie.Refresh, - RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, - ValidateSessionState: opts.GetProvider().ValidateSession, + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: opts.GetProvider().RefreshSession, + ValidateSession: opts.GetProvider().ValidateSession, })) return chain diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 1bd0a9a4..b3737581 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -24,12 +24,12 @@ type StoredSessionLoaderOptions struct { RefreshPeriod time.Duration // Provider based sesssion refreshing - RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) // Provider based session validation. // If the sesssion is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. - ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool + ValidateSession func(context.Context, *sessionsapi.SessionState) bool } // NewStoredSessionLoader creates a new storedSessionLoader which loads @@ -38,10 +38,10 @@ type StoredSessionLoaderOptions struct { // If a session was loader by a previous handler, it will not be replaced. func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { ss := &storedSessionLoader{ - store: opts.SessionStore, - refreshPeriod: opts.RefreshPeriod, - refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, - validateSessionState: opts.ValidateSessionState, + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + sessionRefresher: opts.RefreshSession, + sessionValidator: opts.ValidateSession, } return ss.loadSession } @@ -49,10 +49,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor // storedSessionLoader is responsible for loading sessions from cookie // identified sessions in the session store. type storedSessionLoader struct { - store sessionsapi.SessionStore - refreshPeriod time.Duration - refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) - validateSessionState func(context.Context, *sessionsapi.SessionState) bool + store sessionsapi.SessionStore + refreshPeriod time.Duration + sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) + sessionValidator func(context.Context, *sessionsapi.SessionState) bool } // loadSession attempts to load a session as identified by the request cookies. @@ -120,37 +120,38 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req } logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) - refreshed, err := s.refreshSessionWithProvider(rw, req, session) + err := s.refreshSession(rw, req, session) if err != nil { return err } - if !refreshed { - // Session wasn't refreshed, so make sure it's still valid - return s.validateSession(req.Context(), session) - } - return nil + // Validate all sessions after any Redeem/Refresh operation + return s.validateSession(req.Context(), session) } -// refreshSessionWithProvider attempts to refresh the sessinon with the provider +// refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. -func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { - refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) +func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil { - return false, fmt.Errorf("error refreshing access token: %v", err) + return fmt.Errorf("error refreshing access token: %v", err) } if !refreshed { - return false, nil + return nil } + // If we refreshed, update the `CreatedAt` time to reset the refresh timer + // TODO: Implement + // session.CreatedAtNow() + // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - return false, fmt.Errorf("error saving session: %v", err) + return fmt.Errorf("error saving session: %v", err) } - return true, nil + return nil } // validateSession checks whether the session has expired and performs @@ -161,7 +162,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess return errors.New("session is expired") } - if !s.validateSessionState(ctx, session) { + if !s.sessionValidator(ctx, session) { return errors.New("session is invalid") } diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 3d8dd087..2ec134c9 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -109,10 +109,10 @@ var _ = Describe("Stored Session Suite", func() { rw := httptest.NewRecorder() opts := &StoredSessionLoaderOptions{ - SessionStore: in.store, - RefreshPeriod: in.refreshPeriod, - RefreshSessionIfNeeded: in.refreshSession, - ValidateSessionState: in.validateSession, + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: in.refreshSession, + ValidateSession: in.validateSession, } // Create the handler with a next handler that will capture the session @@ -261,7 +261,7 @@ var _ = Describe("Stored Session Suite", func() { s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, store: &fakeSessionStore{}, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { case refresh: @@ -272,7 +272,7 @@ var _ = Describe("Stored Session Suite", func() { return false, errors.New("error refreshing session") } }, - validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { validated = true return ss.AccessToken != "Invalid" }, @@ -364,7 +364,7 @@ var _ = Describe("Stored Session Suite", func() { ) }) - Context("refreshSessionWithProvider", func() { + Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { session *sessionsapi.SessionState expectedErr error @@ -388,7 +388,7 @@ var _ = Describe("Stored Session Suite", func() { return nil }, }, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: return true, nil @@ -402,13 +402,12 @@ var _ = Describe("Stored Session Suite", func() { req := httptest.NewRequest("", "/", nil) req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) - refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) + err := s.refreshSession(nil, req, in.session) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { Expect(err).ToNot(HaveOccurred()) } - Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(saved).To(Equal(in.expectSaved)) }, Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ @@ -416,7 +415,6 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, }, expectedErr: nil, - expectRefreshed: false, expectSaved: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ @@ -424,7 +422,6 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: refresh, }, expectedErr: nil, - expectRefreshed: true, expectSaved: true, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ @@ -434,7 +431,6 @@ var _ = Describe("Stored Session Suite", func() { ExpiresOn: &now, }, expectedErr: errors.New("error refreshing access token: error refreshing session"), - expectRefreshed: false, expectSaved: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ @@ -443,7 +439,6 @@ var _ = Describe("Stored Session Suite", func() { AccessToken: "NoSave", }, expectedErr: errors.New("error saving session: unable to save session"), - expectRefreshed: false, expectSaved: true, }), ) @@ -454,7 +449,7 @@ var _ = Describe("Stored Session Suite", func() { BeforeEach(func() { s = &storedSessionLoader{ - validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { return ss.AccessToken == "Valid" }, } diff --git a/providers/azure_test.go b/providers/azure_test.go index bb44f20a..9d539c63 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -345,7 +345,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { expires := time.Now().Add(time.Duration(1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) + refreshNeeded, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) assert.False(t, refreshNeeded) } @@ -373,9 +373,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) { expires := time.Now().Add(time.Duration(-1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) + + refreshed, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) - assert.True(t, refreshNeeded) + assert.True(t, refreshed) assert.NotEqual(t, session, nil) assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken) diff --git a/providers/facebook.go b/providers/facebook.go index e3babc0d..6db9c38d 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess return r.Email, nil } -// ValidateSessionState validates the AccessToken +// ValidateSession validates the AccessToken func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) } diff --git a/providers/gitlab.go b/providers/gitlab.go index 18f77fe7..ca9a8bf2 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() { } } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } @@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions return true, nil } -func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } c := oauth2.Config{ @@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses if err != nil { return fmt.Errorf("unable to update session: %v", err) } - s.AccessToken = newSession.AccessToken - s.IDToken = newSession.IDToken - s.RefreshToken = newSession.RefreshToken - s.CreatedAt = newSession.CreatedAt - s.ExpiresOn = newSession.ExpiresOn - s.Email = newSession.Email - return + *s = *newSession + + return nil } type gitlabUserInfo struct { diff --git a/providers/google.go b/providers/google.go index b669156d..49eae1c1 100644 --- a/providers/google.go +++ b/providers/google.go @@ -266,10 +266,9 @@ func userInGroup(service *admin.Service, group string, email string) bool { return false } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } diff --git a/providers/linkedin.go b/providers/linkedin.go index 58217952..115d4c99 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess return email, nil } -// ValidateSessionState validates the AccessToken +// ValidateSession validates the AccessToken func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) } diff --git a/providers/oidc.go b/providers/oidc.go index 9e7bf56f..3e1e79a8 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS return true } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new Access Token (and optional ID token) if required -func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 7fae3368..9879678d 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { User: "11223344", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, "janedoe@example.com", existingSession.Email) @@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { Email: "changeit", User: "changeit", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, defaultIDToken.Email, existingSession.Email) diff --git a/providers/provider_default.go b/providers/provider_default.go index 1bde54b7..be57f0e5 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -126,10 +126,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS return validateToken(ctx, p, s.AccessToken, nil) } -// RefreshSessionIfNeeded should refresh the user's session if required and -// do nothing if a refresh is not required -func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { - return false, nil +// RefreshSession refreshes the user's session +func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) { + if s == nil { + return false, nil + } + + // Pretend `RefreshSession` occured so `ValidateSession` isn't called + // on every request after any potential set refresh period elapses. + return true, nil } // CreateSessionFromToken converts Bearer IDTokens into sessions diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 0bd2f4f0..8474baae 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) { p := &ProviderData{} expires := time.Now().Add(time.Duration(-11) * time.Minute) - refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ + refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{ ExpiresOn: &expires, }) assert.Equal(t, false, refreshed) diff --git a/providers/providers.go b/providers/providers.go index 0340c420..d21409c2 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -9,14 +9,14 @@ import ( // Provider represents an upstream identity provider implementation type Provider interface { Data() *ProviderData + GetLoginURL(redirectURI, finalRedirect string, nonce string) string + Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) // Deprecated: Migrate to EnrichSession GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) - GetLoginURL(redirectURI, state, nonce string) string - Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) EnrichSession(ctx context.Context, s *sessions.SessionState) error Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) ValidateSession(ctx context.Context, s *sessions.SessionState) bool - RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) + RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) } From 7fa6d2d024b4a46d477bc1a4da58e377ee60fbce Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 6 Mar 2021 15:33:40 -0800 Subject: [PATCH 2/6] Manage session time fields centrally --- oauthproxy.go | 15 +++++++++++--- pkg/apis/sessions/session_state.go | 29 ++++++++++++++++++++++++--- pkg/middleware/stored_session.go | 3 +-- pkg/sessions/cookie/session_store.go | 3 +-- pkg/sessions/persistence/manager.go | 3 +-- providers/azure.go | 30 +++++++++++----------------- providers/gitlab.go | 12 ++++++----- providers/google.go | 25 +++++++++++------------ providers/logingov.go | 17 ++++++++-------- providers/oidc.go | 9 +++++---- providers/provider_default.go | 9 ++++++--- 11 files changed, 91 insertions(+), 64 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index b0c94eb0..c3a5693d 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e if err != nil { return nil, err } + + // Force setting these in case the Provider didn't + if s.CreatedAt == nil { + s.CreatedAtNow() + } + if s.ExpiresOn == nil { + s.ExpiresIn(p.CookieOptions.Expire) + } + return s, nil } @@ -861,9 +870,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ - "Expires": time.Unix(0, 0).Format(time.RFC1123), - "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", - "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ + "Expires": time.Unix(0, 0).Format(time.RFC1123), + "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", + "X-Accel-Expire": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ } // prepareNoCache prepares headers for preventing browser caching. diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index e1ee4a6c..9e77609e 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -11,6 +11,7 @@ import ( "time" "unicode/utf8" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/pierrec/lz4" "github.com/vmihailenco/msgpack/v4" @@ -32,7 +33,8 @@ type SessionState struct { Groups []string `msgpack:"g,omitempty"` PreferredUsername string `msgpack:"pu,omitempty"` - Lock Lock `msgpack:"-"` + Clock clock.Clock `msgpack:"-"` + Lock Lock `msgpack:"-"` } func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { @@ -63,9 +65,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { return s.Lock.Peek(ctx) } +// CreatedAtNow sets a SessionState's CreatedAt to now +func (s *SessionState) CreatedAtNow() { + now := s.Clock.Now() + s.CreatedAt = &now +} + +// SetExpiresOn sets an expiration +func (s *SessionState) SetExpiresOn(exp time.Time) { + s.ExpiresOn = &exp +} + +// ExpiresIn sets an expiration a certain duration from CreatedAt. +// CreatedAt will be set to time.Now if it is unset. +func (s *SessionState) ExpiresIn(d time.Duration) { + if s.CreatedAt == nil { + s.CreatedAtNow() + } + exp := s.CreatedAt.Add(d) + s.ExpiresOn = &exp +} + // IsExpired checks whether the session has expired func (s *SessionState) IsExpired() bool { - if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { + if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { return true } return false @@ -74,7 +97,7 @@ func (s *SessionState) IsExpired() bool { // Age returns the age of a session func (s *SessionState) Age() time.Duration { if s.CreatedAt != nil && !s.CreatedAt.IsZero() { - return time.Now().Truncate(time.Second).Sub(*s.CreatedAt) + return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) } return 0 } diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index b3737581..9f69ba64 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -142,8 +142,7 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R } // If we refreshed, update the `CreatedAt` time to reset the refresh timer - // TODO: Implement - // session.CreatedAtNow() + session.CreatedAtNow() // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index ce51ed07..1b3c12de 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -36,8 +36,7 @@ type SessionStore struct { // within Cookies set on the HTTP response writer func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { - now := time.Now() - ss.CreatedAt = &now + ss.CreatedAtNow() } value, err := s.cookieForSession(ss) if err != nil { diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go index 49225171..3215b257 100644 --- a/pkg/sessions/persistence/manager.go +++ b/pkg/sessions/persistence/manager.go @@ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager { // from the persistent data store. func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { if s.CreatedAt == nil || s.CreatedAt.IsZero() { - now := time.Now() - s.CreatedAt = &now + s.CreatedAtNow() } tckt, err := decodeTicketFromRequest(req, m.Options) diff --git a/providers/azure.go b/providers/azure.go index f66d3764..46d7e302 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* return nil, err } - created := time.Now() - expires := time.Unix(jsonResponse.ExpiresOn, 0) - session := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, } + session.CreatedAtNow() + session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) @@ -239,10 +236,9 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st return email, nil } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } @@ -257,7 +253,7 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions. return true, nil } -func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { params := url.Values{} params.Add("client_id", p.ClientID) params.Add("client_secret", p.ClientSecret) @@ -271,25 +267,23 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess IDToken string `json:"id_token"` } - err = requests.New(p.RedeemURL.String()). + err := requests.New(p.RedeemURL.String()). WithContext(ctx). WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). Do(). UnmarshalInto(&jsonResponse) - if err != nil { - return + return err } - now := time.Now() - expires := time.Unix(jsonResponse.ExpiresOn, 0) s.AccessToken = jsonResponse.AccessToken s.IDToken = jsonResponse.IDToken s.RefreshToken = jsonResponse.RefreshToken - s.CreatedAt = &now - s.ExpiresOn = &expires + + s.CreatedAtNow() + s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) @@ -312,7 +306,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess } } - return + return nil } func makeAzureHeader(accessToken string) http.Header { diff --git a/providers/gitlab.go b/providers/gitlab.go index ca9a8bf2..a2b11df7 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -259,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) } } - created := time.Now() - return &sessions.SessionState{ + ss := &sessions.SessionState{ AccessToken: token.AccessToken, IDToken: getIDToken(token), RefreshToken: token.RefreshToken, - CreatedAt: &created, - ExpiresOn: &idToken.Expiry, - }, nil + } + + ss.CreatedAtNow() + ss.SetExpiresOn(idToken.Expiry) + + return ss, nil } // ValidateSession checks that the session's IDToken is still valid diff --git a/providers/google.go b/providers/google.go index 49eae1c1..0cfd3e1c 100644 --- a/providers/google.go +++ b/providers/google.go @@ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( return nil, err } - created := time.Now() - expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - - return &sessions.SessionState{ + ss := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, Email: c.Email, User: c.Subject, - }, nil + } + ss.CreatedAtNow() + ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) + + return ss, nil } // EnrichSession checks the listed Google Groups configured and adds any // that the user is a member of to session.Groups. -func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { +func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error { // TODO (@NickMeves) - Move to pure EnrichSession logic and stop // reusing legacy `groupValidator`. // @@ -272,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, nil } - newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken) + newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken) if err != nil { return false, err } @@ -285,12 +284,12 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } - origExpiration := s.ExpiresOn - expires := time.Now().Add(duration).Truncate(time.Second) s.AccessToken = newToken s.IDToken = newIDToken - s.ExpiresOn = &expires - logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration) + + s.CreatedAtNow() + s.ExpiresIn(ttl) + return true, nil } diff --git a/providers/logingov.go b/providers/logingov.go index 0f625208..43f361f3 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { +func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) { if code == "" { return nil, ErrMissingCode } @@ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) return nil, err } - created := time.Now() - expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) - - // Store the data that we found in the session state - return &sessions.SessionState{ + session := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: &created, - ExpiresOn: &expires, Email: email, - }, nil + } + + session.CreatedAtNow() + session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) + + return session, nil } // GetLoginURL overrides GetLoginURL to add login.gov parameters diff --git a/providers/oidc.go b/providers/oidc.go index 3e1e79a8..2cbbd009 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -226,7 +226,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) ss.AccessToken = token ss.IDToken = token ss.RefreshToken = "" - ss.ExpiresOn = &idToken.Expiry + + ss.CreatedAtNow() + ss.SetExpiresOn(idToken.Expiry) return ss, nil } @@ -256,9 +258,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r ss.RefreshToken = token.RefreshToken ss.IDToken = getIDToken(token) - created := time.Now() - ss.CreatedAt = &created - ss.ExpiresOn = &token.Expiry + ss.CreatedAtNow() + ss.SetExpiresOn(token.Expiry) return ss, nil } diff --git a/providers/provider_default.go b/providers/provider_default.go index be57f0e5..0a62c240 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net/url" - "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" @@ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s if err != nil { return nil, err } + // TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration if token := values.Get("access_token"); token != "" { - created := time.Now() - return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil + ss := &sessions.SessionState{ + AccessToken: token, + } + ss.CreatedAtNow() + return ss, nil } return nil, fmt.Errorf("no access token found %s", result.Body()) From 593125152daa9e48ca34a8a8b021a3fcb234a13c Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 6 Mar 2021 15:48:31 -0800 Subject: [PATCH 3/6] Standarize provider refresh implemention & logging --- oauthproxy.go | 6 +-- pkg/apis/sessions/session_state_test.go | 24 ++++++++++ pkg/middleware/stored_session.go | 20 ++++---- pkg/middleware/stored_session_test.go | 64 +++++++++++++++++-------- providers/azure.go | 12 +++-- providers/azure_test.go | 12 +---- providers/google.go | 27 +++++------ providers/oidc.go | 1 - providers/provider_default.go | 7 ++- providers/provider_default_test.go | 20 +++++--- 10 files changed, 123 insertions(+), 70 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index c3a5693d..e2d20ed6 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -870,9 +870,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ - "Expires": time.Unix(0, 0).Format(time.RFC1123), - "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", - "X-Accel-Expire": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ + "Expires": time.Unix(0, 0).Format(time.RFC1123), + "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", + "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ } // prepareNoCache prepares headers for preventing browser caching. diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 0121b4b6..e6b9ff39 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -16,6 +16,30 @@ func timePtr(t time.Time) *time.Time { return &t } +func TestCreatedAtNow(t *testing.T) { + g := NewWithT(t) + ss := &SessionState{} + + now := time.Unix(1234567890, 0) + ss.Clock.Set(now) + + ss.CreatedAtNow() + g.Expect(*ss.CreatedAt).To(Equal(now)) +} + +func TestExpiresIn(t *testing.T) { + g := NewWithT(t) + ss := &SessionState{} + + now := time.Unix(1234567890, 0) + ss.Clock.Set(now) + + ttl := time.Duration(743) * time.Second + ss.ExpiresIn(ttl) + + g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl))) +} + func TestString(t *testing.T) { g := NewWithT(t) created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z") diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 9f69ba64..85974867 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -108,11 +108,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // refreshSessionIfNeeded will attempt to refresh a session if the session // is older than the refresh period. -// It is assumed that if the provider refreshes the session, the session is now -// valid. -// If the session requires refreshing but the provider does not refresh it, -// we must validate the session to ensure that the returned session is still -// valid. +// Success or fail, we will then validate the session. func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { // Refresh is disabled or the session is not old enough, do nothing @@ -122,10 +118,12 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) err := s.refreshSession(rw, req, session) if err != nil { - return err + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.Errorf("Unable to refresh session: %v", err) } - // Validate all sessions after any Redeem/Refresh operation + // Validate all sessions after any Redeem/Refresh operation (fail or success) return s.validateSession(req.Context(), session) } @@ -134,7 +132,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil { - return fmt.Errorf("error refreshing access token: %v", err) + return fmt.Errorf("error refreshing tokens: %v", err) } if !refreshed { @@ -142,6 +140,12 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R } // If we refreshed, update the `CreatedAt` time to reset the refresh timer + // + // HACK: + // Providers that don't implement `RefreshSession` use the default + // implementation. It always returns `refreshed == true`, so the + // `session.CreatedAt` is updated and doesn't trigger `ValidateSession` + // every subsequent request. session.CreatedAtNow() // Because the session was refreshed, make sure to save it diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 2ec134c9..9c9a4b92 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -10,6 +10,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -24,8 +25,9 @@ var _ = Describe("Stored Session Suite", func() { var ctx = context.Background() Context("StoredSessionLoader", func() { - createdPast := time.Now().Add(-5 * time.Minute) - createdFuture := time.Now().Add(5 * time.Minute) + now := time.Now() + createdPast := now.Add(-5 * time.Minute) + createdFuture := now.Add(5 * time.Minute) var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { @@ -85,6 +87,14 @@ var _ = Describe("Stored Session Suite", func() { }, } + BeforeEach(func() { + clock.Set(now) + }) + + AfterEach(func() { + clock.Reset() + }) + type storedSessionLoaderTableInput struct { requestHeaders http.Header existingSession *sessionsapi.SessionState @@ -208,6 +218,21 @@ var _ = Describe("Stored Session Suite", func() { existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: "Refreshed", + CreatedAt: &now, + ExpiresOn: &createdFuture, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, @@ -216,7 +241,7 @@ var _ = Describe("Stored Session Suite", func() { refreshSession: defaultRefreshFunc, validateSession: defaultValidateFunc, }), - Entry("when the provider refresh fails", storedSessionLoaderTableInput{ + Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{ requestHeaders: http.Header{ "Cookie": []string{"_oauth2_proxy=RefreshError"}, }, @@ -225,7 +250,7 @@ var _ = Describe("Stored Session Suite", func() { store: defaultSessionStore, refreshPeriod: 1 * time.Minute, refreshSession: defaultRefreshFunc, - validateSession: defaultValidateFunc, + validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false }, }), Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ requestHeaders: http.Header{ @@ -326,7 +351,7 @@ var _ = Describe("Stored Session Suite", func() { }, expectedErr: nil, expectRefreshed: true, - expectValidated: false, + expectValidated: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -339,15 +364,15 @@ var _ = Describe("Stored Session Suite", func() { expectRefreshed: true, expectValidated: true, }), - Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ + Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: "RefreshError", CreatedAt: &createdPast, }, - expectedErr: errors.New("error refreshing access token: error refreshing session"), + expectedErr: nil, expectRefreshed: true, - expectValidated: false, + expectValidated: true, }), Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -366,10 +391,9 @@ var _ = Describe("Stored Session Suite", func() { Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectSaved bool + session *sessionsapi.SessionState + expectedErr error + expectSaved bool } now := time.Now() @@ -414,15 +438,15 @@ var _ = Describe("Stored Session Suite", func() { session: &sessionsapi.SessionState{ RefreshToken: noRefresh, }, - expectedErr: nil, - expectSaved: false, + expectedErr: nil, + expectSaved: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ @@ -430,16 +454,16 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &now, ExpiresOn: &now, }, - expectedErr: errors.New("error refreshing access token: error refreshing session"), - expectSaved: false, + expectedErr: errors.New("error refreshing tokens: error refreshing session"), + expectSaved: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, AccessToken: "NoSave", }, - expectedErr: errors.New("error saving session: unable to save session"), - expectSaved: true, + expectedErr: errors.New("error saving session: unable to save session"), + expectSaved: true, }), ) }) diff --git a/providers/azure.go b/providers/azure.go index 46d7e302..39beb836 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -242,21 +242,23 @@ func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionS return false, nil } - origExpiration := s.ExpiresOn - err := p.redeemRefreshToken(ctx, s) if err != nil { return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) return true, nil } func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { + clientSecret, err := p.GetClientSecret() + if err != nil { + return err + } + params := url.Values{} params.Add("client_id", p.ClientID) - params.Add("client_secret", p.ClientSecret) + params.Add("client_secret", clientSecret) params.Add("refresh_token", s.RefreshToken) params.Add("grant_type", "refresh_token") @@ -267,7 +269,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess IDToken string `json:"id_token"` } - err := requests.New(p.RedeemURL.String()). + err = requests.New(p.RedeemURL.String()). WithContext(ctx). WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). diff --git a/providers/azure_test.go b/providers/azure_test.go index 9d539c63..78592df7 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) } -func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { - p := testAzureProvider("") - - expires := time.Now().Add(time.Duration(1) * time.Hour) - session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSession(context.Background(), session) - assert.Equal(t, nil, err) - assert.False(t, refreshNeeded) -} - -func TestAzureProviderRefreshWhenExpired(t *testing.T) { +func TestAzureProviderRefresh(t *testing.T) { email := "foo@example.com" idToken := idTokenClaims{Email: email} idTokenString, err := newSignedTestIDToken(idToken) diff --git a/providers/google.go b/providers/google.go index 0cfd3e1c..a467c50f 100644 --- a/providers/google.go +++ b/providers/google.go @@ -271,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, nil } - newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken) + err := p.redeemRefreshToken(ctx, s) if err != nil { return false, err } @@ -284,26 +284,20 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } - s.AccessToken = newToken - s.IDToken = newIDToken - - s.CreatedAtNow() - s.ExpiresIn(ttl) - return true, nil } -func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { +func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } params := url.Values{} params.Add("client_id", p.ClientID) params.Add("client_secret", clientSecret) - params.Add("refresh_token", refreshToken) + params.Add("refresh_token", s.RefreshToken) params.Add("grant_type", "refresh_token") var data struct { @@ -320,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st Do(). UnmarshalInto(&data) if err != nil { - return "", "", 0, err + return err } - token = data.AccessToken - idToken = data.IDToken - expires = time.Duration(data.ExpiresIn) * time.Second - return + s.AccessToken = data.AccessToken + s.IDToken = data.IDToken + + s.CreatedAtNow() + s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second) + + return nil } diff --git a/providers/oidc.go b/providers/oidc.go index 2cbbd009..b1711d54 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -154,7 +154,6 @@ func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - logger.Printf("refreshed session: %s", s) return true, nil } diff --git a/providers/provider_default.go b/providers/provider_default.go index 0a62c240..d364501b 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -135,8 +135,13 @@ func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionStat return false, nil } - // Pretend `RefreshSession` occured so `ValidateSession` isn't called + // HACK: + // Pretend `RefreshSession` occurred so `ValidateSession` isn't called // on every request after any potential set refresh period elapses. + // See `middleware.refreshSession` for detailed logic & explanation. + // + // Intentionally doesn't use `ErrNotImplemented` since all providers will + // call this and we don't want to force them to implement this dummy logic. return true, nil } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 8474baae..2ba72f25 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -14,12 +14,20 @@ import ( func TestRefresh(t *testing.T) { p := &ProviderData{} - expires := time.Now().Add(time.Duration(-11) * time.Minute) - refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{ - ExpiresOn: &expires, - }) - assert.Equal(t, false, refreshed) - assert.Equal(t, nil, err) + now := time.Unix(1234567890, 10) + expires := time.Unix(1234567890, 0) + + ss := &sessions.SessionState{} + ss.Clock.Set(now) + ss.SetExpiresOn(expires) + + refreshed, err := p.RefreshSession(context.Background(), ss) + assert.True(t, refreshed) + assert.NoError(t, err) + + refreshed, err = p.RefreshSession(context.Background(), nil) + assert.False(t, refreshed) + assert.NoError(t, err) } func TestAcrValuesNotConfigured(t *testing.T) { From d91c3f867d91620929ab90b64913d48d3f9c0353 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 12 Jun 2021 11:18:19 -0700 Subject: [PATCH 4/6] Remove validation for invalid legacy v6.0.0 sessions The reflect.DeepCopy doesn't play nice with the new Lock and Clock fields in sessions. And it added unneeded session deserialization logic to every request. --- pkg/apis/sessions/session_state.go | 47 ++---------------------------- 1 file changed, 3 insertions(+), 44 deletions(-) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index 9e77609e..752d8afb 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -3,18 +3,14 @@ package sessions import ( "bytes" "context" - "errors" "fmt" - "io" - "io/ioutil" - "reflect" - "time" - "unicode/utf8" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/pierrec/lz4" "github.com/vmihailenco/msgpack/v4" + "io" + "io/ioutil" + "time" ) // SessionState is used to store information about the currently authenticated user session @@ -200,11 +196,6 @@ func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*Ses return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) } - err = ss.validate() - if err != nil { - return nil, err - } - return &ss, nil } @@ -258,35 +249,3 @@ func lz4Decompress(compressed []byte) ([]byte, error) { return payload, nil } - -// validate ensures the decoded session is non-empty and contains valid data -// -// Non-empty check is needed due to ensure the non-authenticated AES-CFB -// decryption doesn't result in garbage data that collides with a valid -// MessagePack header bytes (which MessagePack will unmarshal to an empty -// default SessionState). <1% chance, but observed with random test data. -// -// UTF-8 check ensures the strings are valid and not raw bytes overloaded -// into Latin-1 encoding. The occurs when legacy unencrypted fields are -// decrypted with AES-CFB which results in random bytes. -func (s *SessionState) validate() error { - for _, field := range []string{ - s.User, - s.Email, - s.PreferredUsername, - s.AccessToken, - s.IDToken, - s.RefreshToken, - } { - if !utf8.ValidString(field) { - return errors.New("invalid non-UTF8 field in session") - } - } - - empty := new(SessionState) - if reflect.DeepEqual(*s, *empty) { - return errors.New("invalid empty session unmarshalled") - } - - return nil -} From baf6cf38165215f6a95a2b3d9646b04b0a87c336 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 12 Jun 2021 11:28:22 -0700 Subject: [PATCH 5/6] Remove mutex from local Clock instances They will only be used in tests, but it doesn't play nice with copy operations many tests use. The linter was not happy. While the global clock needs mutexes for parallelism, local Clocks only used it for Set/Add and didn't even use the mutex for actual time functions. --- pkg/apis/sessions/session_state.go | 8 +++++--- pkg/clock/clock.go | 7 ------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index 752d8afb..08538dae 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -4,13 +4,14 @@ import ( "bytes" "context" "fmt" + "io" + "io/ioutil" + "time" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/pierrec/lz4" "github.com/vmihailenco/msgpack/v4" - "io" - "io/ioutil" - "time" ) // SessionState is used to store information about the currently authenticated user session @@ -29,6 +30,7 @@ type SessionState struct { Groups []string `msgpack:"g,omitempty"` PreferredUsername string `msgpack:"pu,omitempty"` + // Internal helpers, not serialized Clock clock.Clock `msgpack:"-"` Lock Lock `msgpack:"-"` } diff --git a/pkg/clock/clock.go b/pkg/clock/clock.go index 34b7bf23..887bf0aa 100644 --- a/pkg/clock/clock.go +++ b/pkg/clock/clock.go @@ -63,13 +63,10 @@ func Reset() *clockapi.Mock { // package. type Clock struct { mock *clockapi.Mock - sync.Mutex } // Set sets the Clock to a clock.Mock at the given time.Time func (c *Clock) Set(t time.Time) { - c.Lock() - defer c.Unlock() if c.mock == nil { c.mock = clockapi.NewMock() } @@ -79,8 +76,6 @@ func (c *Clock) Set(t time.Time) { // Add moves clock forward time.Duration if it is mocked. It will error // if the clock is not mocked. func (c *Clock) Add(d time.Duration) error { - c.Lock() - defer c.Unlock() if c.mock == nil { return errors.New("clock not mocked") } @@ -91,8 +86,6 @@ func (c *Clock) Add(d time.Duration) error { // Reset removes local clock.Mock. Returns any existing Mock if set in case // lingering time operations are attached to it. func (c *Clock) Reset() *clockapi.Mock { - c.Lock() - defer c.Unlock() existing := c.mock c.mock = nil return existing From ff914d7e175566c59de93f88e4c05320f5367d07 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 12 Jun 2021 11:41:03 -0700 Subject: [PATCH 6/6] Use `ErrNotImplemented` in default refresh implementation --- CHANGELOG.md | 6 ++++++ pkg/middleware/stored_session.go | 29 +++++++++++++++++---------- pkg/middleware/stored_session_test.go | 27 +++++++++++++++++++++++-- providers/provider_default.go | 15 ++------------ providers/provider_default_test.go | 6 +++--- 5 files changed, 54 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2446269..55841190 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,16 @@ ## Important Notes +- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session + deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade + to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret` + value and force all sessions to reauthenticate. + ## Breaking Changes ## Changes since v7.1.3 +- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves) - [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed) - [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed) - [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 85974867..6748816f 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -11,19 +11,20 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) -// StoredSessionLoaderOptions cotnains all of the requirements to construct +// StoredSessionLoaderOptions contains all of the requirements to construct // a stored session loader. // All options must be provided. type StoredSessionLoaderOptions struct { - // Session storage basckend + // Session storage backend SessionStore sessionsapi.SessionStore // How often should sessions be refreshed RefreshPeriod time.Duration - // Provider based sesssion refreshing + // Provider based session refreshing RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) // Provider based session validation. @@ -115,7 +116,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req return nil } - logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) err := s.refreshSession(rw, req, session) if err != nil { // If a preemptive refresh fails, we still keep the session @@ -131,21 +132,27 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req // and will save the session if it was updated. func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { refreshed, err := s.sessionRefresher(req.Context(), session) - if err != nil { + if err != nil && !errors.Is(err, providers.ErrNotImplemented) { return fmt.Errorf("error refreshing tokens: %v", err) } + // HACK: + // Providers that don't implement `RefreshSession` use the default + // implementation which returns `ErrNotImplemented`. + // Pretend it refreshed to reset the refresh timer so that `ValidateSession` + // isn't triggered every subsequent request and is only called once during + // this request. + if errors.Is(err, providers.ErrNotImplemented) { + refreshed = true + } + + // Session not refreshed, nothing to persist. if !refreshed { return nil } // If we refreshed, update the `CreatedAt` time to reset the refresh timer - // - // HACK: - // Providers that don't implement `RefreshSession` use the default - // implementation. It always returns `refreshed == true`, so the - // `session.CreatedAt` is updated and doesn't trigger `ValidateSession` - // every subsequent request. + // (In case underlying provider implementations forget) session.CreatedAtNow() // Because the session was refreshed, make sure to save it diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 9c9a4b92..782390b6 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -11,6 +11,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" + "github.com/oauth2-proxy/oauth2-proxy/v7/providers" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -18,8 +19,9 @@ import ( var _ = Describe("Stored Session Suite", func() { const ( - refresh = "Refresh" - noRefresh = "NoRefresh" + refresh = "Refresh" + noRefresh = "NoRefresh" + notImplemented = "NotImplemented" ) var ctx = context.Background() @@ -293,6 +295,8 @@ var _ = Describe("Stored Session Suite", func() { return true, nil case noRefresh: return false, nil + case notImplemented: + return false, providers.ErrNotImplemented default: return false, errors.New("error refreshing session") } @@ -364,6 +368,16 @@ var _ = Describe("Stored Session Suite", func() { expectRefreshed: true, expectValidated: true, }), + Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: notImplemented, + CreatedAt: &createdPast, + }, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + }), Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ @@ -418,6 +432,8 @@ var _ = Describe("Stored Session Suite", func() { return true, nil case noRefresh: return false, nil + case notImplemented: + return false, providers.ErrNotImplemented default: return false, errors.New("error refreshing session") } @@ -448,6 +464,13 @@ var _ = Describe("Stored Session Suite", func() { expectedErr: nil, expectSaved: true, }), + Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: notImplemented, + }, + expectedErr: nil, + expectSaved: true, + }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: "RefreshError", diff --git a/providers/provider_default.go b/providers/provider_default.go index d364501b..7a641b1e 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -130,19 +130,8 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS } // RefreshSession refreshes the user's session -func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) { - if s == nil { - return false, nil - } - - // HACK: - // Pretend `RefreshSession` occurred so `ValidateSession` isn't called - // on every request after any potential set refresh period elapses. - // See `middleware.refreshSession` for detailed logic & explanation. - // - // Intentionally doesn't use `ErrNotImplemented` since all providers will - // call this and we don't want to force them to implement this dummy logic. - return true, nil +func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) { + return false, ErrNotImplemented } // CreateSessionFromToken converts Bearer IDTokens into sessions diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 2ba72f25..5d4ed1af 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -22,12 +22,12 @@ func TestRefresh(t *testing.T) { ss.SetExpiresOn(expires) refreshed, err := p.RefreshSession(context.Background(), ss) - assert.True(t, refreshed) - assert.NoError(t, err) + assert.False(t, refreshed) + assert.Equal(t, ErrNotImplemented, err) refreshed, err = p.RefreshSession(context.Background(), nil) assert.False(t, refreshed) - assert.NoError(t, err) + assert.Equal(t, ErrNotImplemented, err) } func TestAcrValuesNotConfigured(t *testing.T) {