call EnrichSession after session refresh

Signed-off-by: ugurtafrali <tafraliugur@gmail.com>
This commit is contained in:
ugurtafrali 2026-04-07 22:01:56 +03:00
parent da9123f740
commit 8d19ba11f8
3 changed files with 54 additions and 13 deletions

View File

@ -421,6 +421,7 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
EnrichSession: provider.EnrichSession,
}))
return chain

View File

@ -48,6 +48,9 @@ type StoredSessionLoaderOptions struct {
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation.
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
// Provider based session enriching after a token refresh.
EnrichSession func(context.Context, *sessionsapi.SessionState) error
}
// NewStoredSessionLoader creates a new storedSessionLoader which loads
@ -60,6 +63,7 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
sessionEnricher: opts.EnrichSession,
}
return ss.loadSession
}
@ -71,6 +75,7 @@ type storedSessionLoader struct {
refreshPeriod time.Duration
sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
sessionEnricher func(context.Context, *sessionsapi.SessionState) error
}
// loadSession attempts to load a session as identified by the request cookies.
@ -230,6 +235,14 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
// (In case underlying provider implementations forget)
session.CreatedAtNow()
// Re-enrich the session after a real token refresh so that providers
// which fetch extra data (e.g. groups via Graph API) can repopulate it.
if !errors.Is(err, providers.ErrNotImplemented) && s.sessionEnricher != nil {
if enrichErr := s.sessionEnricher(req.Context(), session); enrichErr != nil {
return fmt.Errorf("error enriching session after refresh: %v", enrichErr)
}
}
// Because the session was refreshed, make sure to save it
err = s.store.Save(rw, req, session)
if err != nil {

View File

@ -638,9 +638,11 @@ var _ = Describe("Stored Session Suite", func() {
Context("refreshSession", func() {
type refreshSessionWithProviderTableInput struct {
session *sessionsapi.SessionState
expectedErr error
expectSaved bool
session *sessionsapi.SessionState
expectedErr error
expectSaved bool
expectEnriched bool
sessionEnricher func(context.Context, *sessionsapi.SessionState) error
}
now := time.Now()
@ -648,6 +650,7 @@ var _ = Describe("Stored Session Suite", func() {
DescribeTable("when refreshing with the provider",
func(in refreshSessionWithProviderTableInput) {
saved := false
enriched := false
s := &storedSessionLoader{
store: &fakeSessionStore{
@ -671,6 +674,13 @@ var _ = Describe("Stored Session Suite", func() {
return false, errors.New("error refreshing session")
}
},
sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error {
enriched = true
if in.sessionEnricher != nil {
return in.sessionEnricher(context.Background(), nil)
}
return nil
},
}
req := httptest.NewRequest("", "/", nil)
@ -682,27 +692,31 @@ var _ = Describe("Stored Session Suite", func() {
Expect(err).ToNot(HaveOccurred())
}
Expect(saved).To(Equal(in.expectSaved))
Expect(enriched).To(Equal(in.expectEnriched))
},
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: noRefresh,
},
expectedErr: nil,
expectSaved: false,
expectedErr: nil,
expectSaved: false,
expectEnriched: false,
}),
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: refresh,
},
expectedErr: nil,
expectSaved: true,
expectedErr: nil,
expectSaved: true,
expectEnriched: true,
}),
Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: notImplemented,
},
expectedErr: nil,
expectSaved: true,
expectedErr: nil,
expectSaved: true,
expectEnriched: false,
}),
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
@ -710,16 +724,29 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &now,
ExpiresOn: &now,
},
expectedErr: errors.New("error refreshing tokens: error refreshing session"),
expectSaved: false,
expectedErr: errors.New("error refreshing tokens: error refreshing session"),
expectSaved: false,
expectEnriched: false,
}),
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: refresh,
AccessToken: "NoSave",
},
expectedErr: errors.New("error saving session: unable to save session"),
expectSaved: true,
expectedErr: errors.New("error saving session: unable to save session"),
expectSaved: true,
expectEnriched: true,
}),
Entry("when enriching the session returns an error", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: refresh,
},
sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error {
return errors.New("enrich error")
},
expectedErr: errors.New("error enriching session after refresh: enrich error"),
expectSaved: false,
expectEnriched: true,
}),
)
})