diff --git a/oauthproxy.go b/oauthproxy.go index 3efe66fd..871405c5 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -421,6 +421,7 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi RefreshPeriod: opts.Cookie.Refresh, RefreshSession: provider.RefreshSession, ValidateSession: provider.ValidateSession, + EnrichSession: provider.EnrichSession, })) return chain diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 72c364e7..98c6a909 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -48,6 +48,9 @@ type StoredSessionLoaderOptions struct { // If the sesssion is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. ValidateSession func(context.Context, *sessionsapi.SessionState) bool + + // Provider based session enriching after a token refresh. + EnrichSession func(context.Context, *sessionsapi.SessionState) error } // NewStoredSessionLoader creates a new storedSessionLoader which loads @@ -60,6 +63,7 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor refreshPeriod: opts.RefreshPeriod, sessionRefresher: opts.RefreshSession, sessionValidator: opts.ValidateSession, + sessionEnricher: opts.EnrichSession, } return ss.loadSession } @@ -71,6 +75,7 @@ type storedSessionLoader struct { refreshPeriod time.Duration sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) sessionValidator func(context.Context, *sessionsapi.SessionState) bool + sessionEnricher func(context.Context, *sessionsapi.SessionState) error } // loadSession attempts to load a session as identified by the request cookies. @@ -230,6 +235,14 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R // (In case underlying provider implementations forget) session.CreatedAtNow() + // Re-enrich the session after a real token refresh so that providers + // which fetch extra data (e.g. groups via Graph API) can repopulate it. + if !errors.Is(err, providers.ErrNotImplemented) && s.sessionEnricher != nil { + if enrichErr := s.sessionEnricher(req.Context(), session); enrichErr != nil { + return fmt.Errorf("error enriching session after refresh: %v", enrichErr) + } + } + // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) if err != nil { diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index d8e78f2f..d3d25dfa 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -638,9 +638,11 @@ var _ = Describe("Stored Session Suite", func() { Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { - session *sessionsapi.SessionState - expectedErr error - expectSaved bool + session *sessionsapi.SessionState + expectedErr error + expectSaved bool + expectEnriched bool + sessionEnricher func(context.Context, *sessionsapi.SessionState) error } now := time.Now() @@ -648,6 +650,7 @@ var _ = Describe("Stored Session Suite", func() { DescribeTable("when refreshing with the provider", func(in refreshSessionWithProviderTableInput) { saved := false + enriched := false s := &storedSessionLoader{ store: &fakeSessionStore{ @@ -671,6 +674,13 @@ var _ = Describe("Stored Session Suite", func() { return false, errors.New("error refreshing session") } }, + sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error { + enriched = true + if in.sessionEnricher != nil { + return in.sessionEnricher(context.Background(), nil) + } + return nil + }, } req := httptest.NewRequest("", "/", nil) @@ -682,27 +692,31 @@ var _ = Describe("Stored Session Suite", func() { Expect(err).ToNot(HaveOccurred()) } Expect(saved).To(Equal(in.expectSaved)) + Expect(enriched).To(Equal(in.expectEnriched)) }, Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: noRefresh, }, - expectedErr: nil, - expectSaved: false, + expectedErr: nil, + expectSaved: false, + expectEnriched: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, + expectEnriched: true, }), Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: notImplemented, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, + expectEnriched: false, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ @@ -710,16 +724,29 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &now, ExpiresOn: &now, }, - expectedErr: errors.New("error refreshing tokens: error refreshing session"), - expectSaved: false, + expectedErr: errors.New("error refreshing tokens: error refreshing session"), + expectSaved: false, + expectEnriched: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, AccessToken: "NoSave", }, - expectedErr: errors.New("error saving session: unable to save session"), - expectSaved: true, + expectedErr: errors.New("error saving session: unable to save session"), + expectSaved: true, + expectEnriched: true, + }), + Entry("when enriching the session returns an error", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + }, + sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error { + return errors.New("enrich error") + }, + expectedErr: errors.New("error enriching session after refresh: enrich error"), + expectSaved: false, + expectEnriched: true, }), ) })