diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index c913a4ef..ef8eedd3 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -639,9 +639,11 @@ var _ = Describe("Stored Session Suite", func() { Context("refreshSession", func() { type refreshSessionWithProviderTableInput struct { - session *sessionsapi.SessionState - expectedErr error - expectSaved bool + session *sessionsapi.SessionState + expectedErr error + expectSaved bool + expectEnriched bool + sessionEnricher func(context.Context, *sessionsapi.SessionState) error } now := time.Now() @@ -649,6 +651,7 @@ var _ = Describe("Stored Session Suite", func() { DescribeTable("when refreshing with the provider", func(in refreshSessionWithProviderTableInput) { saved := false + enriched := false s := &storedSessionLoader{ store: &fakeSessionStore{ @@ -672,6 +675,13 @@ var _ = Describe("Stored Session Suite", func() { return false, errors.New("error refreshing session") } }, + sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error { + enriched = true + if in.sessionEnricher != nil { + return in.sessionEnricher(context.Background(), nil) + } + return nil + }, } req := httptest.NewRequest("", "/", nil) @@ -683,27 +693,31 @@ var _ = Describe("Stored Session Suite", func() { Expect(err).ToNot(HaveOccurred()) } Expect(saved).To(Equal(in.expectSaved)) + Expect(enriched).To(Equal(in.expectEnriched)) }, Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: noRefresh, }, - expectedErr: nil, - expectSaved: false, + expectedErr: nil, + expectSaved: false, + expectEnriched: false, }), Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, + expectEnriched: true, }), Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: notImplemented, }, - expectedErr: nil, - expectSaved: true, + expectedErr: nil, + expectSaved: true, + expectEnriched: false, }), Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ @@ -711,16 +725,29 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &now, ExpiresOn: &now, }, - expectedErr: errors.New("error refreshing tokens: error refreshing session"), - expectSaved: false, + expectedErr: errors.New("error refreshing tokens: error refreshing session"), + expectSaved: false, + expectEnriched: false, }), Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ session: &sessionsapi.SessionState{ RefreshToken: refresh, AccessToken: "NoSave", }, - expectedErr: errors.New("error saving session: unable to save session"), - expectSaved: true, + expectedErr: errors.New("error saving session: unable to save session"), + expectSaved: true, + expectEnriched: true, + }), + Entry("when enriching the session returns an error", refreshSessionWithProviderTableInput{ + session: &sessionsapi.SessionState{ + RefreshToken: refresh, + }, + sessionEnricher: func(_ context.Context, _ *sessionsapi.SessionState) error { + return errors.New("enrich error") + }, + expectedErr: errors.New("error enriching session after refresh: enrich error"), + expectSaved: false, + expectEnriched: true, }), ) })