From 8d19ba11f85ab76ec396b0907c45db77133f5e6c Mon Sep 17 00:00:00 2001 From: ugurtafrali Date: Tue, 7 Apr 2026 22:01:56 +0300 Subject: [PATCH 1/2] call EnrichSession after session refresh Signed-off-by: ugurtafrali --- oauthproxy.go | 1 + pkg/middleware/stored_session.go | 13 +++++++ pkg/middleware/stored_session_test.go | 53 ++++++++++++++++++++------- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 3efe66fd..871405c5 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -421,6 +421,7 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi RefreshPeriod: opts.Cookie.Refresh, RefreshSession: provider.RefreshSession, ValidateSession: provider.ValidateSession, + EnrichSession: provider.EnrichSession, })) return chain diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 72c364e7..98c6a909 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -48,6 +48,9 @@ type StoredSessionLoaderOptions struct { // If the sesssion is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. ValidateSession func(context.Context, *sessionsapi.SessionState) bool + + // Provider based session enriching after a token refresh. + EnrichSession func(context.Context, *sessionsapi.SessionState) error } // NewStoredSessionLoader creates a new storedSessionLoader which loads @@ -60,6 +63,7 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor refreshPeriod: opts.RefreshPeriod, sessionRefresher: opts.RefreshSession, sessionValidator: opts.ValidateSession, + sessionEnricher: opts.EnrichSession, } return ss.loadSession } @@ -71,6 +75,7 @@ type storedSessionLoader struct { refreshPeriod time.Duration sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) sessionValidator func(context.Context, *sessionsapi.SessionState) bool + sessionEnricher func(context.Context, *sessionsapi.SessionState) error } // loadSession attempts to load a session as identified by the request cookies. @@ -230,6 +235,14 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R // (In case underlying provider implementations forget) session.CreatedAtNow() + // Re-enrich the session after a real token refresh so that providers + // which fetch extra data (e.g. groups via Graph API) can repopulate it. + if !errors.Is(err, providers.ErrNotImplemented) && s.sessionEnricher != nil { + if enrichErr := s.sessionEnricher(req.Context(), session); enrichErr != nil { + return fmt.Errorf("error enriching session after refresh: %v", enrichErr) + } + } + // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) if err != nil { diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index d8e78f2f..d3d25dfa 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -638,9 +638,11 @@ var _ = Describe("Stored Session Suite", func() { Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { - session *sessionsapi.SessionState - expectedErr error - expectSaved bool + session *sessionsapi.SessionState + expectedErr error + expectSaved bool + expectEnriched bool + sessionEnricher func(context.Context, *sessionsapi.SessionState) error } now := time.Now() @@ -648,6 +650,7 @@ var _ = Describe("Stored Session Suite", func() { DescribeTable("when refreshing with the provider", func(in refreshSessionWithProviderTableInput) { saved := false + enriched := false s := &storedSessionLoader{ store: &fakeSessionStore{ @@ -671,6 +674,13 @@ var _ = Describe("Stored Session Suite", func() { return false, errors.New("error refreshing session") } }, + sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error { + enriched = true + if in.sessionEnricher != nil { + return in.sessionEnricher(context.Background(), nil) + } + return nil + }, } req := httptest.NewRequest("", "/", nil) @@ -682,27 +692,31 @@ var _ = Describe("Stored Session Suite", func() { Expect(err).ToNot(HaveOccurred()) } Expect(saved).To(Equal(in.expectSaved)) + Expect(enriched).To(Equal(in.expectEnriched)) }, Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: noRefresh, }, - expectedErr: nil, - expectSaved: false, + expectedErr: nil, + expectSaved: false, + expectEnriched: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, + expectEnriched: true, }), Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: notImplemented, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, + expectEnriched: false, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ @@ -710,16 +724,29 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &now, ExpiresOn: &now, }, - expectedErr: errors.New("error refreshing tokens: error refreshing session"), - expectSaved: false, + expectedErr: errors.New("error refreshing tokens: error refreshing session"), + expectSaved: false, + expectEnriched: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, AccessToken: "NoSave", }, - expectedErr: errors.New("error saving session: unable to save session"), - expectSaved: true, + expectedErr: errors.New("error saving session: unable to save session"), + expectSaved: true, + expectEnriched: true, + }), + Entry("when enriching the session returns an error", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + }, + sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error { + return errors.New("enrich error") + }, + expectedErr: errors.New("error enriching session after refresh: enrich error"), + expectSaved: false, + expectEnriched: true, }), ) }) From 3e538674e85070cabe56f79d3e52528af80d96a8 Mon Sep 17 00:00:00 2001 From: ugurtafrali Date: Fri, 10 Apr 2026 00:42:19 +0300 Subject: [PATCH 2/2] Address review feedback --- oauthproxy.go | 1 - pkg/middleware/stored_session.go | 13 ------------- 2 files changed, 14 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 871405c5..3efe66fd 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -421,7 +421,6 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi RefreshPeriod: opts.Cookie.Refresh, RefreshSession: provider.RefreshSession, ValidateSession: provider.ValidateSession, - EnrichSession: provider.EnrichSession, })) return chain diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 98c6a909..72c364e7 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -48,9 +48,6 @@ type StoredSessionLoaderOptions struct { // If the sesssion is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. ValidateSession func(context.Context, *sessionsapi.SessionState) bool - - // Provider based session enriching after a token refresh. - EnrichSession func(context.Context, *sessionsapi.SessionState) error } // NewStoredSessionLoader creates a new storedSessionLoader which loads @@ -63,7 +60,6 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor refreshPeriod: opts.RefreshPeriod, sessionRefresher: opts.RefreshSession, sessionValidator: opts.ValidateSession, - sessionEnricher: opts.EnrichSession, } return ss.loadSession } @@ -75,7 +71,6 @@ type storedSessionLoader struct { refreshPeriod time.Duration sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) sessionValidator func(context.Context, *sessionsapi.SessionState) bool - sessionEnricher func(context.Context, *sessionsapi.SessionState) error } // loadSession attempts to load a session as identified by the request cookies. @@ -235,14 +230,6 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R // (In case underlying provider implementations forget) session.CreatedAtNow() - // Re-enrich the session after a real token refresh so that providers - // which fetch extra data (e.g. groups via Graph API) can repopulate it. - if !errors.Is(err, providers.ErrNotImplemented) && s.sessionEnricher != nil { - if enrichErr := s.sessionEnricher(req.Context(), session); enrichErr != nil { - return fmt.Errorf("error enriching session after refresh: %v", enrichErr) - } - } - // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) if err != nil {