diff --git a/CHANGELOG.md b/CHANGELOG.md index 05ae8c92..690b4831 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.2.1 +- [#1468](https://github.com/oauth2-proxy/oauth2-proxy/pull/1468) Implement session locking with session state lock (@JoelSpeed, @Bibob7) - [#1489](https://github.com/oauth2-proxy/oauth2-proxy/pull/1489) Fix Docker Buildx push to include build version (@JoelSpeed) - [#1477](https://github.com/oauth2-proxy/oauth2-proxy/pull/1477) Remove provider documentation for `Microsoft Azure AD` (@omBratteng) - [#1204](https://github.com/oauth2-proxy/oauth2-proxy/pull/1204) Added configuration for audience claim (`--oidc-extra-audience`) and ability to specify extra audiences (`--oidc-extra-audience`) allowed passing audience verification. This enables support for AWS Cognito and other issuers that have custom audience claims. Also, this adds the ability to allow multiple audiences. (@kschu91) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 4cbc47eb..1afe6d0c 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -14,6 +14,23 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) +const ( + // When attempting to obtain the lock, if it's not done before this timeout + // then exit and fail the refresh attempt. + // TODO: This should probably be configurable by the end user. + sessionRefreshObtainTimeout = 5 * time.Second + + // Maximum time allowed for a session refresh attempt. + // If the refresh request isn't finished within this time, the lock will be + // released. + // TODO: This should probably be configurable by the end user. + sessionRefreshLockDuration = 2 * time.Second + + // How long to wait after failing to obtain the lock before trying again. + // TODO: This should probably be configurable by the end user. + sessionRefreshRetryPeriod = 10 * time.Millisecond +) + // StoredSessionLoaderOptions contains all of the requirements to construct // a stored session loader. // All options must be provided. @@ -91,13 +108,10 @@ func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { // that is is valid. func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { session, err := s.store.Load(req) - if err != nil { + if err != nil || session == nil { + // No session was found in the storage or error occurred, nothing more to do return nil, err } - if session == nil { - // No session was found in the storage, nothing more to do - return nil, nil - } err = s.refreshSessionIfNeeded(rw, req, session) if err != nil { @@ -111,14 +125,69 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // is older than the refresh period. // Success or fail, we will then validate the session. func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { + if !needsRefresh(s.refreshPeriod, session) { // Refresh is disabled or the session is not old enough, do nothing return nil } - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) - err := s.refreshSession(rw, req, session) + var lockObtained bool + ctx, cancel := context.WithTimeout(context.Background(), sessionRefreshObtainTimeout) + defer cancel() + + for !lockObtained { + select { + case <-ctx.Done(): + return errors.New("timeout obtaining session lock") + default: + err := session.ObtainLock(req.Context(), sessionRefreshLockDuration) + if err != nil && !errors.Is(err, sessionsapi.ErrLockNotObtained) { + return fmt.Errorf("error occurred while trying to obtain lock: %v", err) + } else if errors.Is(err, sessionsapi.ErrLockNotObtained) { + time.Sleep(sessionRefreshRetryPeriod) + continue + } + // No error means we obtained the lock + lockObtained = true + } + } + + // The rest of this function is carried out under lock, but we must release it + // wherever we exit from this function. + defer func() { + if session == nil { + return + } + if err := session.ReleaseLock(req.Context()); err != nil { + logger.Errorf("unable to release lock: %v", err) + } + }() + + // Reload the session in case it was changed underneath us. + freshSession, err := s.store.Load(req) if err != nil { + return fmt.Errorf("could not load session: %v", err) + } + if freshSession == nil { + return errors.New("session no longer exists, it may have been removed by another request") + } + // Restore the state of the fresh session into the original pointer. + // This is important so that changes are passed up the to the parent scope. + lock := session.Lock + *session = *freshSession + + // Ensure we maintain the session lock after we have refreshed the session. + // Loading from the session store creates a new lock in the session. + session.Lock = lock + + if !needsRefresh(s.refreshPeriod, session) { + // The session must have already been refreshed while we were waiting to + // obtain the lock. + return nil + } + + // We are holding the lock and the session needs a refresh + logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + if err := s.refreshSession(rw, req, session); err != nil { // If a preemptive refresh fails, we still keep the session // if validateSession succeeds. logger.Errorf("Unable to refresh session: %v", err) @@ -128,6 +197,11 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req return s.validateSession(req.Context(), session) } +// needsRefresh determines whether we should attempt to refresh a session or not. +func needsRefresh(refreshPeriod time.Duration, session *sessionsapi.SessionState) bool { + return refreshPeriod > time.Duration(0) && session.Age() > refreshPeriod +} + // refreshSession attempts to refresh the session with the provider // and will save the session if it was updated. func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 782390b6..c462278f 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "time" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" @@ -17,9 +18,74 @@ import ( . "github.com/onsi/gomega" ) +type testLock struct { + locked bool + obtainOnAttempt int + obtainAttempts int + obtainError error +} + +func (l *testLock) Obtain(_ context.Context, _ time.Duration) error { + l.obtainAttempts++ + if l.obtainAttempts < l.obtainOnAttempt { + return sessionsapi.ErrLockNotObtained + } + if l.obtainError != nil { + return l.obtainError + } + l.locked = true + return nil +} + +func (l *testLock) Peek(_ context.Context) (bool, error) { + return l.locked, nil +} + +func (l *testLock) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *testLock) Release(_ context.Context) error { + l.locked = false + return nil +} + +type testLockConcurrent struct { + mu sync.RWMutex + locked bool +} + +func (l *testLockConcurrent) Obtain(_ context.Context, _ time.Duration) error { + l.mu.Lock() + defer l.mu.Unlock() + if l.locked { + return sessionsapi.ErrLockNotObtained + } + l.locked = true + return nil +} + +func (l *testLockConcurrent) Peek(_ context.Context) (bool, error) { + l.mu.RLock() + defer l.mu.RUnlock() + return l.locked, nil +} + +func (l *testLockConcurrent) Refresh(_ context.Context, _ time.Duration) error { + return nil +} + +func (l *testLockConcurrent) Release(_ context.Context) error { + l.mu.Lock() + defer l.mu.Unlock() + l.locked = false + return nil +} + var _ = Describe("Stored Session Suite", func() { const ( refresh = "Refresh" + refreshed = "Refreshed" noRefresh = "NoRefresh" notImplemented = "NotImplemented" ) @@ -34,7 +100,7 @@ var _ = Describe("Stored Session Suite", func() { var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: - ss.RefreshToken = "Refreshed" + ss.RefreshToken = refreshed return true, nil case noRefresh: return false, nil @@ -181,6 +247,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -222,6 +289,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: "Refreshed", CreatedAt: &now, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -237,6 +305,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, }, store: defaultSessionStore, refreshPeriod: 1 * time.Minute, @@ -266,15 +335,116 @@ var _ = Describe("Stored Session Suite", func() { validateSession: defaultValidateFunc, }), ) + + type storedSessionLoaderConcurrentTableInput struct { + existingSession *sessionsapi.SessionState + refreshPeriod time.Duration + numConcReqs int + } + + DescribeTable("when serving concurrent requests", + func(in storedSessionLoaderConcurrentTableInput) { + lockConc := &testLockConcurrent{} + + lock := &sync.RWMutex{} + existingSession := *in.existingSession // deep copy existingSession state + existingSession.Lock = lockConc + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + lock.RLock() + defer lock.RUnlock() + session := existingSession + return &session, nil + }, + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, session *sessionsapi.SessionState) error { + lock.Lock() + defer lock.Unlock() + existingSession = *session + return nil + }, + } + + refreshedChan := make(chan bool, in.numConcReqs) + for i := 0; i < in.numConcReqs; i++ { + go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { + scope := &middlewareapi.RequestScope{ + Session: nil, + } + + // Set up the request with the request header and a request scope + req := httptest.NewRequest("", "/", nil) + req = middlewareapi.AddRequestScope(req, scope) + + rw := httptest.NewRecorder() + + sessionRefreshed := false + opts := &StoredSessionLoaderOptions{ + SessionStore: store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) { + time.Sleep(10 * time.Millisecond) + sessionRefreshed = true + return true, nil + }, + ValidateSession: func(context.Context, *sessionsapi.SessionState) bool { + return true + }, + } + + handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + handler.ServeHTTP(rw, req) + + refreshedChan <- sessionRefreshed + }(refreshedChan, lockConc) + } + var refreshedSlice []bool + for i := 0; i < in.numConcReqs; i++ { + refreshedSlice = append(refreshedSlice, <-refreshedChan) + } + sessionRefreshedCount := 0 + for _, sessionRefreshed := range refreshedSlice { + if sessionRefreshed { + sessionRefreshedCount++ + } + } + Expect(sessionRefreshedCount).To(Equal(1)) + }, + Entry("with two concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 2, + refreshPeriod: 1 * time.Minute, + }), + Entry("with 5 concurrent requests", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 5, + refreshPeriod: 1 * time.Minute, + }), + Entry("with one request", storedSessionLoaderConcurrentTableInput{ + existingSession: &sessionsapi.SessionState{ + RefreshToken: refresh, + CreatedAt: &createdPast, + }, + numConcReqs: 1, + refreshPeriod: 1 * time.Minute, + }), + ) }) Context("refreshSessionIfNeeded", func() { type refreshSessionIfNeededTableInput struct { - refreshPeriod time.Duration - session *sessionsapi.SessionState - expectedErr error - expectRefreshed bool - expectValidated bool + refreshPeriod time.Duration + session *sessionsapi.SessionState + concurrentSessionRefresh bool + expectedErr error + expectRefreshed bool + expectValidated bool + expectedLockObtained bool } createdPast := time.Now().Add(-5 * time.Minute) @@ -285,9 +455,28 @@ var _ = Describe("Stored Session Suite", func() { refreshed := false validated := false + session := &sessionsapi.SessionState{} + *session = *in.session + if in.concurrentSessionRefresh { + // Update the session that Load returns. + // This simulates a concurrent refresh in the background. + session.CreatedAt = &createdFuture + } + store := &fakeSessionStore{ + LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { + // Loading the session from the provider creates a new lock + session.Lock = &testLock{} + return session, nil + }, + SaveFunc: func(_ http.ResponseWriter, _ *http.Request, s *sessionsapi.SessionState) error { + *session = *s + return nil + }, + } + s := &storedSessionLoader{ refreshPeriod: in.refreshPeriod, - store: &fakeSessionStore{}, + store: store, sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken { @@ -316,46 +505,90 @@ var _ = Describe("Stored Session Suite", func() { } Expect(refreshed).To(Equal(in.expectRefreshed)) Expect(validated).To(Equal(in.expectValidated)) + testLock, ok := in.session.Lock.(*testLock) + Expect(ok).To(Equal(true)) + + if in.expectedLockObtained { + Expect(testLock.obtainAttempts).Should(BeNumerically(">", 0), "Expected at least one attempt at obtaining the session lock") + } + Expect(testLock.locked).To(BeFalse(), "Expected lock should always be released") }, Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: time.Duration(0), session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdFuture, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: false, - expectValidated: false, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: false, }), Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, + }), + Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &testLock{ + obtainOnAttempt: 4, + }, + }, + concurrentSessionRefresh: true, + expectedErr: nil, + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: true, + }), + Entry("when obtaining lock failed with a valid session", refreshSessionIfNeededTableInput{ + refreshPeriod: 1 * time.Minute, + session: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + Lock: &testLock{ + obtainError: sessionsapi.ErrLockNotObtained, + }, + }, + expectedErr: errors.New("timeout obtaining session lock"), + expectRefreshed: false, + expectValidated: false, + expectedLockObtained: true, }), Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, @@ -363,42 +596,38 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), - Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ + Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ RefreshToken: notImplemented, CreatedAt: &createdPast, + Lock: &testLock{}, }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, + expectedErr: nil, + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), - Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ - refreshPeriod: 1 * time.Minute, - session: &sessionsapi.SessionState{ - RefreshToken: "RefreshError", - CreatedAt: &createdPast, - }, - expectedErr: nil, - expectRefreshed: true, - expectValidated: true, - }), - Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ + Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ refreshPeriod: 1 * time.Minute, session: &sessionsapi.SessionState{ AccessToken: "Invalid", RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Lock: &testLock{}, }, - expectedErr: errors.New("session is invalid"), - expectRefreshed: true, - expectValidated: true, + expectedErr: errors.New("session is invalid"), + expectRefreshed: true, + expectValidated: true, + expectedLockObtained: true, }), ) })