call EnrichSession after session refresh
Signed-off-by: ugurtafrali <tafraliugur@gmail.com>
This commit is contained in:
parent
da9123f740
commit
8d19ba11f8
|
|
@ -421,6 +421,7 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi
|
|||
RefreshPeriod: opts.Cookie.Refresh,
|
||||
RefreshSession: provider.RefreshSession,
|
||||
ValidateSession: provider.ValidateSession,
|
||||
EnrichSession: provider.EnrichSession,
|
||||
}))
|
||||
|
||||
return chain
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ type StoredSessionLoaderOptions struct {
|
|||
// If the sesssion is older than `RefreshPeriod` but the provider doesn't
|
||||
// refresh it, we must re-validate using this validation.
|
||||
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
|
||||
|
||||
// Provider based session enriching after a token refresh.
|
||||
EnrichSession func(context.Context, *sessionsapi.SessionState) error
|
||||
}
|
||||
|
||||
// NewStoredSessionLoader creates a new storedSessionLoader which loads
|
||||
|
|
@ -60,6 +63,7 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
|
|||
refreshPeriod: opts.RefreshPeriod,
|
||||
sessionRefresher: opts.RefreshSession,
|
||||
sessionValidator: opts.ValidateSession,
|
||||
sessionEnricher: opts.EnrichSession,
|
||||
}
|
||||
return ss.loadSession
|
||||
}
|
||||
|
|
@ -71,6 +75,7 @@ type storedSessionLoader struct {
|
|||
refreshPeriod time.Duration
|
||||
sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
|
||||
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
|
||||
sessionEnricher func(context.Context, *sessionsapi.SessionState) error
|
||||
}
|
||||
|
||||
// loadSession attempts to load a session as identified by the request cookies.
|
||||
|
|
@ -230,6 +235,14 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
|
|||
// (In case underlying provider implementations forget)
|
||||
session.CreatedAtNow()
|
||||
|
||||
// Re-enrich the session after a real token refresh so that providers
|
||||
// which fetch extra data (e.g. groups via Graph API) can repopulate it.
|
||||
if !errors.Is(err, providers.ErrNotImplemented) && s.sessionEnricher != nil {
|
||||
if enrichErr := s.sessionEnricher(req.Context(), session); enrichErr != nil {
|
||||
return fmt.Errorf("error enriching session after refresh: %v", enrichErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Because the session was refreshed, make sure to save it
|
||||
err = s.store.Save(rw, req, session)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -638,9 +638,11 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
|
||||
Context("refreshSession", func() {
|
||||
type refreshSessionWithProviderTableInput struct {
|
||||
session *sessionsapi.SessionState
|
||||
expectedErr error
|
||||
expectSaved bool
|
||||
session *sessionsapi.SessionState
|
||||
expectedErr error
|
||||
expectSaved bool
|
||||
expectEnriched bool
|
||||
sessionEnricher func(context.Context, *sessionsapi.SessionState) error
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
|
@ -648,6 +650,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
DescribeTable("when refreshing with the provider",
|
||||
func(in refreshSessionWithProviderTableInput) {
|
||||
saved := false
|
||||
enriched := false
|
||||
|
||||
s := &storedSessionLoader{
|
||||
store: &fakeSessionStore{
|
||||
|
|
@ -671,6 +674,13 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
return false, errors.New("error refreshing session")
|
||||
}
|
||||
},
|
||||
sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error {
|
||||
enriched = true
|
||||
if in.sessionEnricher != nil {
|
||||
return in.sessionEnricher(context.Background(), nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
|
|
@ -682,27 +692,31 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(saved).To(Equal(in.expectSaved))
|
||||
Expect(enriched).To(Equal(in.expectEnriched))
|
||||
},
|
||||
Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: noRefresh,
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectSaved: false,
|
||||
expectedErr: nil,
|
||||
expectSaved: false,
|
||||
expectEnriched: false,
|
||||
}),
|
||||
Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: refresh,
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectSaved: true,
|
||||
expectedErr: nil,
|
||||
expectSaved: true,
|
||||
expectEnriched: true,
|
||||
}),
|
||||
Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: notImplemented,
|
||||
},
|
||||
expectedErr: nil,
|
||||
expectSaved: true,
|
||||
expectedErr: nil,
|
||||
expectSaved: true,
|
||||
expectEnriched: false,
|
||||
}),
|
||||
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
|
|
@ -710,16 +724,29 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
CreatedAt: &now,
|
||||
ExpiresOn: &now,
|
||||
},
|
||||
expectedErr: errors.New("error refreshing tokens: error refreshing session"),
|
||||
expectSaved: false,
|
||||
expectedErr: errors.New("error refreshing tokens: error refreshing session"),
|
||||
expectSaved: false,
|
||||
expectEnriched: false,
|
||||
}),
|
||||
Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: refresh,
|
||||
AccessToken: "NoSave",
|
||||
},
|
||||
expectedErr: errors.New("error saving session: unable to save session"),
|
||||
expectSaved: true,
|
||||
expectedErr: errors.New("error saving session: unable to save session"),
|
||||
expectSaved: true,
|
||||
expectEnriched: true,
|
||||
}),
|
||||
Entry("when enriching the session returns an error", refreshSessionWithProviderTableInput{
|
||||
session: &sessionsapi.SessionState{
|
||||
RefreshToken: refresh,
|
||||
},
|
||||
sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error {
|
||||
return errors.New("enrich error")
|
||||
},
|
||||
expectedErr: errors.New("error enriching session after refresh: enrich error"),
|
||||
expectSaved: false,
|
||||
expectEnriched: true,
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in New Issue