From 7e80e5596b64ed9d25c9e2e119c7025da104f941 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 6 Mar 2021 15:33:13 -0800 Subject: [PATCH] RefreshSessions immediately when called --- oauthproxy.go | 8 ++--- pkg/middleware/stored_session.go | 49 ++++++++++++++------------- pkg/middleware/stored_session_test.go | 25 ++++++-------- providers/azure_test.go | 7 ++-- providers/facebook.go | 2 +- providers/gitlab.go | 21 +++++------- providers/google.go | 7 ++-- providers/linkedin.go | 2 +- providers/oidc.go | 7 ++-- providers/oidc_test.go | 4 +-- providers/provider_default.go | 13 ++++--- providers/provider_default_test.go | 2 +- providers/providers.go | 6 ++-- 13 files changed, 74 insertions(+), 79 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index d6479609..b0c94eb0 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -361,10 +361,10 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt } chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ - SessionStore: sessionStore, - RefreshPeriod: opts.Cookie.Refresh, - RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, - ValidateSessionState: opts.GetProvider().ValidateSession, + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: opts.GetProvider().RefreshSession, + ValidateSession: opts.GetProvider().ValidateSession, })) return chain diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 1bd0a9a4..b3737581 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -24,12 +24,12 @@ type StoredSessionLoaderOptions struct { RefreshPeriod time.Duration // Provider based sesssion refreshing - RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) + RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) // Provider based session validation. // If the sesssion is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. - ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool + ValidateSession func(context.Context, *sessionsapi.SessionState) bool } // NewStoredSessionLoader creates a new storedSessionLoader which loads @@ -38,10 +38,10 @@ type StoredSessionLoaderOptions struct { // If a session was loader by a previous handler, it will not be replaced. func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { ss := &storedSessionLoader{ - store: opts.SessionStore, - refreshPeriod: opts.RefreshPeriod, - refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, - validateSessionState: opts.ValidateSessionState, + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + sessionRefresher: opts.RefreshSession, + sessionValidator: opts.ValidateSession, } return ss.loadSession } @@ -49,10 +49,10 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor // storedSessionLoader is responsible for loading sessions from cookie // identified sessions in the session store. type storedSessionLoader struct { - store sessionsapi.SessionStore - refreshPeriod time.Duration - refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) - validateSessionState func(context.Context, *sessionsapi.SessionState) bool + store sessionsapi.SessionStore + refreshPeriod time.Duration + sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) + sessionValidator func(context.Context, *sessionsapi.SessionState) bool } // loadSession attempts to load a session as identified by the request cookies. @@ -120,37 +120,38 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req } logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) - refreshed, err := s.refreshSessionWithProvider(rw, req, session) + err := s.refreshSession(rw, req, session) if err != nil { return err } - if !refreshed { - // Session wasn't refreshed, so make sure it's still valid - return s.validateSession(req.Context(), session) - } - return nil + // Validate all sessions after any Redeem/Refresh operation + return s.validateSession(req.Context(), session) } -// refreshSessionWithProvider attempts to refresh the sessinon with the provider +// refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. -func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { - refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) +func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { + refreshed, err := s.sessionRefresher(req.Context(), session) if err != nil { - return false, fmt.Errorf("error refreshing access token: %v", err) + return fmt.Errorf("error refreshing access token: %v", err) } if !refreshed { - return false, nil + return nil } + // If we refreshed, update the `CreatedAt` time to reset the refresh timer + // TODO: Implement + // session.CreatedAtNow() + // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) - return false, fmt.Errorf("error saving session: %v", err) + return fmt.Errorf("error saving session: %v", err) } - return true, nil + return nil } // validateSession checks whether the session has expired and performs @@ -161,7 +162,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess return errors.New("session is expired") } - if !s.validateSessionState(ctx, session) { + if !s.sessionValidator(ctx, session) { return errors.New("session is invalid") } diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 3d8dd087..2ec134c9 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -109,10 +109,10 @@ var _ = Describe("Stored Session Suite", func() { rw := httptest.NewRecorder() opts := &StoredSessionLoaderOptions{ - SessionStore: in.store, - RefreshPeriod: in.refreshPeriod, - RefreshSessionIfNeeded: in.refreshSession, - ValidateSessionState: in.validateSession, + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: in.refreshSession, + ValidateSession: in.validateSession, } // Create the handler with a next handler that will capture the session @@ -261,7 +261,7 @@ var _ = Describe("Stored Session Suite", func() { s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, store: &fakeSessionStore{}, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { case refresh: @@ -272,7 +272,7 @@ var _ = Describe("Stored Session Suite", func() { return false, errors.New("error refreshing session") } }, - validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { validated = true return ss.AccessToken != "Invalid" }, @@ -364,7 +364,7 @@ var _ = Describe("Stored Session Suite", func() { ) }) - Context("refreshSessionWithProvider", func() { + Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { session *sessionsapi.SessionState expectedErr error @@ -388,7 +388,7 @@ var _ = Describe("Stored Session Suite", func() { return nil }, }, - refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { + sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: return true, nil @@ -402,13 +402,12 @@ var _ = Describe("Stored Session Suite", func() { req := httptest.NewRequest("", "/", nil) req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) - refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) + err := s.refreshSession(nil, req, in.session) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { Expect(err).ToNot(HaveOccurred()) } - Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(saved).To(Equal(in.expectSaved)) }, Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ @@ -416,7 +415,6 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, }, expectedErr: nil, - expectRefreshed: false, expectSaved: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ @@ -424,7 +422,6 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: refresh, }, expectedErr: nil, - expectRefreshed: true, expectSaved: true, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ @@ -434,7 +431,6 @@ var _ = Describe("Stored Session Suite", func() { ExpiresOn: &now, }, expectedErr: errors.New("error refreshing access token: error refreshing session"), - expectRefreshed: false, expectSaved: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ @@ -443,7 +439,6 @@ var _ = Describe("Stored Session Suite", func() { AccessToken: "NoSave", }, expectedErr: errors.New("error saving session: unable to save session"), - expectRefreshed: false, expectSaved: true, }), ) @@ -454,7 +449,7 @@ var _ = Describe("Stored Session Suite", func() { BeforeEach(func() { s = &storedSessionLoader{ - validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { + sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { return ss.AccessToken == "Valid" }, } diff --git a/providers/azure_test.go b/providers/azure_test.go index bb44f20a..9d539c63 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -345,7 +345,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { expires := time.Now().Add(time.Duration(1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) + refreshNeeded, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) assert.False(t, refreshNeeded) } @@ -373,9 +373,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) { expires := time.Now().Add(time.Duration(-1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) + + refreshed, err := p.RefreshSession(context.Background(), session) assert.Equal(t, nil, err) - assert.True(t, refreshNeeded) + assert.True(t, refreshed) assert.NotEqual(t, session, nil) assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken) diff --git a/providers/facebook.go b/providers/facebook.go index e3babc0d..6db9c38d 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess return r.Email, nil } -// ValidateSessionState validates the AccessToken +// ValidateSession validates the AccessToken func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) } diff --git a/providers/gitlab.go b/providers/gitlab.go index 18f77fe7..ca9a8bf2 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() { } } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } @@ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions return true, nil } -func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } c := oauth2.Config{ @@ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses if err != nil { return fmt.Errorf("unable to update session: %v", err) } - s.AccessToken = newSession.AccessToken - s.IDToken = newSession.IDToken - s.RefreshToken = newSession.RefreshToken - s.CreatedAt = newSession.CreatedAt - s.ExpiresOn = newSession.ExpiresOn - s.Email = newSession.Email - return + *s = *newSession + + return nil } type gitlabUserInfo struct { diff --git a/providers/google.go b/providers/google.go index b669156d..49eae1c1 100644 --- a/providers/google.go +++ b/providers/google.go @@ -266,10 +266,9 @@ func userInGroup(service *admin.Service, group string, email string) bool { return false } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required -func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } diff --git a/providers/linkedin.go b/providers/linkedin.go index 58217952..115d4c99 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess return email, nil } -// ValidateSessionState validates the AccessToken +// ValidateSession validates the AccessToken func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) } diff --git a/providers/oidc.go b/providers/oidc.go index 9e7bf56f..3e1e79a8 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS return true } -// RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new Access Token (and optional ID token) if required -func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { +// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens +func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshToken == "" { return false, nil } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 7fae3368..9879678d 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { User: "11223344", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, "janedoe@example.com", existingSession.Email) @@ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { Email: "changeit", User: "changeit", } - refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) + refreshed, err := provider.RefreshSession(context.Background(), existingSession) assert.Equal(t, nil, err) assert.Equal(t, refreshed, true) assert.Equal(t, defaultIDToken.Email, existingSession.Email) diff --git a/providers/provider_default.go b/providers/provider_default.go index 1bde54b7..be57f0e5 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -126,10 +126,15 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS return validateToken(ctx, p, s.AccessToken, nil) } -// RefreshSessionIfNeeded should refresh the user's session if required and -// do nothing if a refresh is not required -func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { - return false, nil +// RefreshSession refreshes the user's session +func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) { + if s == nil { + return false, nil + } + + // Pretend `RefreshSession` occured so `ValidateSession` isn't called + // on every request after any potential set refresh period elapses. + return true, nil } // CreateSessionFromToken converts Bearer IDTokens into sessions diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 0bd2f4f0..8474baae 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -15,7 +15,7 @@ func TestRefresh(t *testing.T) { p := &ProviderData{} expires := time.Now().Add(time.Duration(-11) * time.Minute) - refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ + refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{ ExpiresOn: &expires, }) assert.Equal(t, false, refreshed) diff --git a/providers/providers.go b/providers/providers.go index 0340c420..d21409c2 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -9,14 +9,14 @@ import ( // Provider represents an upstream identity provider implementation type Provider interface { Data() *ProviderData + GetLoginURL(redirectURI, finalRedirect string, nonce string) string + Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) // Deprecated: Migrate to EnrichSession GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) - GetLoginURL(redirectURI, state, nonce string) string - Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) EnrichSession(ctx context.Context, s *sessions.SessionState) error Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) ValidateSession(ctx context.Context, s *sessions.SessionState) bool - RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) + RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) }