Use session to lock to protect concurrent refreshes
This commit is contained in:
		
							parent
							
								
									dc5d2a5cd7
								
							
						
					
					
						commit
						e2c7ff6ddd
					
				|  | @ -14,6 +14,11 @@ import ( | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/providers" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/providers" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | const ( | ||||||
|  | 	SessionLockExpireTime = 5 * time.Second | ||||||
|  | 	SessionLockPeekDelay  = 50 * time.Millisecond | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| // StoredSessionLoaderOptions contains all of the requirements to construct
 | // StoredSessionLoaderOptions contains all of the requirements to construct
 | ||||||
| // a stored session loader.
 | // a stored session loader.
 | ||||||
| // All options must be provided.
 | // All options must be provided.
 | ||||||
|  | @ -91,13 +96,10 @@ func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { | ||||||
| // that is is valid.
 | // that is is valid.
 | ||||||
| func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
| 	session, err := s.store.Load(req) | 	session, err := s.store.Load(req) | ||||||
| 	if err != nil { | 	if err != nil || session == nil { | ||||||
|  | 		// No session was found in the storage or error occurred, nothing more to do
 | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if session == nil { |  | ||||||
| 		// No session was found in the storage, nothing more to do
 |  | ||||||
| 		return nil, nil |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	err = s.refreshSessionIfNeeded(rw, req, session) | 	err = s.refreshSessionIfNeeded(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -116,12 +118,21 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) | 	wasRefreshed, err := s.checkForConcurrentRefresh(session, req) | ||||||
| 	err := s.refreshSession(rw, req, session) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// If a preemptive refresh fails, we still keep the session
 | 		return err | ||||||
| 		// if validateSession succeeds.
 | 	} | ||||||
| 		logger.Errorf("Unable to refresh session: %v", err) | 
 | ||||||
|  | 	// If session was already refreshed via a concurrent request locked skip refreshing,
 | ||||||
|  | 	// because the refreshed session is already loaded from storage
 | ||||||
|  | 	if !wasRefreshed { | ||||||
|  | 		logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) | ||||||
|  | 		err = s.refreshSession(rw, req, session) | ||||||
|  | 		if err != nil { | ||||||
|  | 			// If a preemptive refresh fails, we still keep the session
 | ||||||
|  | 			// if validateSession succeeds.
 | ||||||
|  | 			logger.Errorf("Unable to refresh session: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | 	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | ||||||
|  | @ -131,6 +142,18 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req | ||||||
| // refreshSession attempts to refresh the session with the provider
 | // refreshSession attempts to refresh the session with the provider
 | ||||||
| // and will save the session if it was updated.
 | // and will save the session if it was updated.
 | ||||||
| func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	err := session.ObtainLock(req.Context(), SessionLockExpireTime) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Unable to obtain lock: %v", err) | ||||||
|  | 		return s.handleObtainLockError(req, session) | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		err = session.ReleaseLock(req.Context()) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.Errorf("unable to release lock: %v", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
| 	refreshed, err := s.sessionRefresher(req.Context(), session) | 	refreshed, err := s.sessionRefresher(req.Context(), session) | ||||||
| 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||||
| 		return fmt.Errorf("error refreshing tokens: %v", err) | 		return fmt.Errorf("error refreshing tokens: %v", err) | ||||||
|  | @ -159,11 +182,75 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R | ||||||
| 	err = s.store.Save(rw, req, session) | 	err = s.store.Save(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||||
| 		return fmt.Errorf("error saving session: %v", err) | 		err = fmt.Errorf("error saving session: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *storedSessionLoader) handleObtainLockError(req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	wasRefreshed, err := s.checkForConcurrentRefresh(session, req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Unable to wait for obtained lock: %v", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if !wasRefreshed { | ||||||
|  | 		return errors.New("unable to obtain lock and session was also not refreshed via concurrent request") | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	sessionStored, err := s.store.Load(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("unable to load updated session from store: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if sessionStored == nil { | ||||||
|  | 		return fmt.Errorf("no session available to udpate from store") | ||||||
|  | 	} | ||||||
|  | 	*session = *sessionStored | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { | ||||||
|  | 	var wasLocked bool | ||||||
|  | 	isLocked, err := session.PeekLock(req.Context()) | ||||||
|  | 	for isLocked { | ||||||
|  | 		wasLocked = true | ||||||
|  | 		// delay next peek lock
 | ||||||
|  | 		time.Sleep(SessionLockPeekDelay) | ||||||
|  | 		isLocked, err = session.PeekLock(req.Context()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return wasLocked, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request.
 | ||||||
|  | func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) { | ||||||
|  | 	wasLocked, err := s.waitForPossibleSessionLock(session, req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	refreshed := false | ||||||
|  | 	if wasLocked { | ||||||
|  | 		logger.Printf("Update session from store instead of refreshing") | ||||||
|  | 		err = s.updateSessionFromStore(req, session) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.Errorf("Unable to update session from store: %v", err) | ||||||
|  | 			return false, err | ||||||
|  | 		} | ||||||
|  | 		refreshed = true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return refreshed, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // validateSession checks whether the session has expired and performs
 | // validateSession checks whether the session has expired and performs
 | ||||||
| // provider validation on the session.
 | // provider validation on the session.
 | ||||||
| // An error implies the session is not longer valid.
 | // An error implies the session is not longer valid.
 | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | @ -17,9 +18,104 @@ import ( | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | type TestLock struct { | ||||||
|  | 	Locked            bool | ||||||
|  | 	WasObtained       bool | ||||||
|  | 	WasRefreshed      bool | ||||||
|  | 	WasReleased       bool | ||||||
|  | 	PeekedCount       int | ||||||
|  | 	LockedOnPeekCount int | ||||||
|  | 	ObtainError       error | ||||||
|  | 	PeekError         error | ||||||
|  | 	RefreshError      error | ||||||
|  | 	ReleaseError      error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { | ||||||
|  | 	if l.ObtainError != nil { | ||||||
|  | 		return l.ObtainError | ||||||
|  | 	} | ||||||
|  | 	l.Locked = true | ||||||
|  | 	l.WasObtained = true | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *TestLock) Peek(_ context.Context) (bool, error) { | ||||||
|  | 	if l.PeekError != nil { | ||||||
|  | 		return false, l.PeekError | ||||||
|  | 	} | ||||||
|  | 	locked := l.Locked | ||||||
|  | 	l.Locked = false | ||||||
|  | 	l.PeekedCount++ | ||||||
|  | 	// mainly used to test case when peek initially returns false,
 | ||||||
|  | 	// but when trying to obtain lock, it returns true.
 | ||||||
|  | 	if l.LockedOnPeekCount == l.PeekedCount { | ||||||
|  | 		return true, nil | ||||||
|  | 	} | ||||||
|  | 	return locked, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *TestLock) Refresh(_ context.Context, _ time.Duration) error { | ||||||
|  | 	if l.RefreshError != nil { | ||||||
|  | 		return l.ReleaseError | ||||||
|  | 	} | ||||||
|  | 	l.WasRefreshed = true | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *TestLock) Release(_ context.Context) error { | ||||||
|  | 	if l.ReleaseError != nil { | ||||||
|  | 		return l.ReleaseError | ||||||
|  | 	} | ||||||
|  | 	l.Locked = false | ||||||
|  | 	l.WasReleased = true | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type LockConc struct { | ||||||
|  | 	mu          sync.Mutex | ||||||
|  | 	lock        bool | ||||||
|  | 	disablePeek bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	if l.lock { | ||||||
|  | 		l.mu.Unlock() | ||||||
|  | 		return sessionsapi.ErrLockNotObtained | ||||||
|  | 	} | ||||||
|  | 	l.lock = true | ||||||
|  | 	l.mu.Unlock() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Peek(_ context.Context) (bool, error) { | ||||||
|  | 	var response bool | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	if l.disablePeek { | ||||||
|  | 		response = false | ||||||
|  | 	} else { | ||||||
|  | 		response = l.lock | ||||||
|  | 	} | ||||||
|  | 	l.mu.Unlock() | ||||||
|  | 	return response, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Release(_ context.Context) error { | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	l.lock = false | ||||||
|  | 	l.mu.Unlock() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| var _ = Describe("Stored Session Suite", func() { | var _ = Describe("Stored Session Suite", func() { | ||||||
| 	const ( | 	const ( | ||||||
| 		refresh        = "Refresh" | 		refresh        = "Refresh" | ||||||
|  | 		refreshed      = "Refreshed" | ||||||
| 		noRefresh      = "NoRefresh" | 		noRefresh      = "NoRefresh" | ||||||
| 		notImplemented = "NotImplemented" | 		notImplemented = "NotImplemented" | ||||||
| 	) | 	) | ||||||
|  | @ -34,7 +130,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 			switch ss.RefreshToken { | 			switch ss.RefreshToken { | ||||||
| 			case refresh: | 			case refresh: | ||||||
| 				ss.RefreshToken = "Refreshed" | 				ss.RefreshToken = refreshed | ||||||
| 				return true, nil | 				return true, nil | ||||||
| 			case noRefresh: | 			case noRefresh: | ||||||
| 				return false, nil | 				return false, nil | ||||||
|  | @ -181,6 +277,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
|  | 					Lock:         &sessionsapi.NoOpLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				store:           defaultSessionStore, | 				store:           defaultSessionStore, | ||||||
| 				refreshPeriod:   1 * time.Minute, | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | @ -222,6 +319,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: "Refreshed", | 					RefreshToken: "Refreshed", | ||||||
| 					CreatedAt:    &now, | 					CreatedAt:    &now, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
|  | 					Lock:         &sessionsapi.NoOpLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				store:           defaultSessionStore, | 				store:           defaultSessionStore, | ||||||
| 				refreshPeriod:   1 * time.Minute, | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | @ -237,6 +335,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: "RefreshError", | 					RefreshToken: "RefreshError", | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
|  | 					Lock:         &sessionsapi.NoOpLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				store:           defaultSessionStore, | 				store:           defaultSessionStore, | ||||||
| 				refreshPeriod:   1 * time.Minute, | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | @ -266,15 +365,109 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				validateSession: defaultValidateFunc, | 				validateSession: defaultValidateFunc, | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
|  | 
 | ||||||
|  | 		type storedSessionLoaderConcurrentTableInput struct { | ||||||
|  | 			existingSession *sessionsapi.SessionState | ||||||
|  | 			refreshPeriod   time.Duration | ||||||
|  | 			numConcReqs     int | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("when serving concurrent requests", | ||||||
|  | 			func(in storedSessionLoaderConcurrentTableInput) { | ||||||
|  | 				lockConc := &LockConc{} | ||||||
|  | 
 | ||||||
|  | 				refreshedChan := make(chan bool, in.numConcReqs) | ||||||
|  | 				for i := 0; i < in.numConcReqs; i++ { | ||||||
|  | 					go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { | ||||||
|  | 						existingSession := *in.existingSession // deep copy existingSession state
 | ||||||
|  | 						existingSession.Lock = lockConc | ||||||
|  | 						store := &fakeSessionStore{ | ||||||
|  | 							LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 								return &existingSession, nil | ||||||
|  | 							}, | ||||||
|  | 							SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { | ||||||
|  | 								return nil | ||||||
|  | 							}, | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						scope := &middlewareapi.RequestScope{ | ||||||
|  | 							Session: nil, | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						// Set up the request with the request header and a request scope
 | ||||||
|  | 						req := httptest.NewRequest("", "/", nil) | ||||||
|  | 						req = middlewareapi.AddRequestScope(req, scope) | ||||||
|  | 
 | ||||||
|  | 						rw := httptest.NewRecorder() | ||||||
|  | 
 | ||||||
|  | 						sessionRefreshed := false | ||||||
|  | 						opts := &StoredSessionLoaderOptions{ | ||||||
|  | 							SessionStore:  store, | ||||||
|  | 							RefreshPeriod: in.refreshPeriod, | ||||||
|  | 							RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) { | ||||||
|  | 								time.Sleep(10 * time.Millisecond) | ||||||
|  | 								sessionRefreshed = true | ||||||
|  | 								return true, nil | ||||||
|  | 							}, | ||||||
|  | 							ValidateSession: func(context.Context, *sessionsapi.SessionState) bool { | ||||||
|  | 								return true | ||||||
|  | 							}, | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) | ||||||
|  | 						handler.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 						refreshedChan <- sessionRefreshed | ||||||
|  | 					}(refreshedChan, lockConc) | ||||||
|  | 				} | ||||||
|  | 				var refreshedSlice []bool | ||||||
|  | 				for i := 0; i < in.numConcReqs; i++ { | ||||||
|  | 					refreshedSlice = append(refreshedSlice, <-refreshedChan) | ||||||
|  | 				} | ||||||
|  | 				sessionRefreshedCount := 0 | ||||||
|  | 				for _, sessionRefreshed := range refreshedSlice { | ||||||
|  | 					if sessionRefreshed { | ||||||
|  | 						sessionRefreshedCount++ | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				Expect(sessionRefreshedCount).To(Equal(1)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with two concurrent requests", storedSessionLoaderConcurrentTableInput{ | ||||||
|  | 				existingSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				numConcReqs:   2, | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with 5 concurrent requests", storedSessionLoaderConcurrentTableInput{ | ||||||
|  | 				existingSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				numConcReqs:   5, | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with one request", storedSessionLoaderConcurrentTableInput{ | ||||||
|  | 				existingSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				numConcReqs:   1, | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	Context("refreshSessionIfNeeded", func() { | 	Context("refreshSessionIfNeeded", func() { | ||||||
| 		type refreshSessionIfNeededTableInput struct { | 		type refreshSessionIfNeededTableInput struct { | ||||||
| 			refreshPeriod   time.Duration | 			refreshPeriod     time.Duration | ||||||
| 			session         *sessionsapi.SessionState | 			sessionStored     bool | ||||||
| 			expectedErr     error | 			session           *sessionsapi.SessionState | ||||||
| 			expectRefreshed bool | 			expectedErr       error | ||||||
| 			expectValidated bool | 			expectRefreshed   bool | ||||||
|  | 			expectValidated   bool | ||||||
|  | 			expectedLockState TestLock | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		createdPast := time.Now().Add(-5 * time.Minute) | 		createdPast := time.Now().Add(-5 * time.Minute) | ||||||
|  | @ -285,9 +478,18 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				refreshed := false | 				refreshed := false | ||||||
| 				validated := false | 				validated := false | ||||||
| 
 | 
 | ||||||
|  | 				store := &fakeSessionStore{} | ||||||
|  | 				if in.sessionStored { | ||||||
|  | 					store = &fakeSessionStore{ | ||||||
|  | 						LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 							return in.session, nil | ||||||
|  | 						}, | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
| 				s := &storedSessionLoader{ | 				s := &storedSessionLoader{ | ||||||
| 					refreshPeriod: in.refreshPeriod, | 					refreshPeriod: in.refreshPeriod, | ||||||
| 					store:         &fakeSessionStore{}, | 					store:         store, | ||||||
| 					sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 					sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 						refreshed = true | 						refreshed = true | ||||||
| 						switch ss.RefreshToken { | 						switch ss.RefreshToken { | ||||||
|  | @ -316,46 +518,117 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				} | 				} | ||||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) | 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||||
| 				Expect(validated).To(Equal(in.expectValidated)) | 				Expect(validated).To(Equal(in.expectValidated)) | ||||||
|  | 				testLock, ok := in.session.Lock.(*TestLock) | ||||||
|  | 				Expect(ok).To(Equal(true)) | ||||||
|  | 
 | ||||||
|  | 				Expect(testLock).To(Equal(&in.expectedLockState)) | ||||||
| 			}, | 			}, | ||||||
| 			Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ | 			Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: time.Duration(0), | 				refreshPeriod: time.Duration(0), | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdFuture, | 					CreatedAt:    &createdFuture, | ||||||
|  | 					Lock:         &TestLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:       nil, | ||||||
| 				expectRefreshed: false, | 				expectRefreshed:   false, | ||||||
| 				expectValidated: false, | 				expectValidated:   false, | ||||||
|  | 				expectedLockState: TestLock{}, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ | 			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: time.Duration(0), | 				refreshPeriod: time.Duration(0), | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
|  | 					Lock:         &TestLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:       nil, | ||||||
| 				expectRefreshed: false, | 				expectRefreshed:   false, | ||||||
| 				expectValidated: false, | 				expectValidated:   false, | ||||||
|  | 				expectedLockState: TestLock{}, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ | 			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdFuture, | 					CreatedAt:    &createdFuture, | ||||||
|  | 					Lock:         &TestLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:       nil, | ||||||
| 				expectRefreshed: false, | 				expectRefreshed:   false, | ||||||
| 				expectValidated: false, | 				expectValidated:   false, | ||||||
|  | 				expectedLockState: TestLock{}, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ | 			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
|  | 					Lock:         &TestLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:     nil, | ||||||
| 				expectRefreshed: true, | 				expectRefreshed: true, | ||||||
| 				expectValidated: true, | 				expectValidated: true, | ||||||
|  | 				expectedLockState: TestLock{ | ||||||
|  | 					Locked:      false, | ||||||
|  | 					WasObtained: true, | ||||||
|  | 					WasReleased: true, | ||||||
|  | 					PeekedCount: 1, | ||||||
|  | 				}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the session is locked and instead loaded from storage", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: noRefresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					Lock: &TestLock{ | ||||||
|  | 						Locked: true, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				sessionStored:   true, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectValidated: true, | ||||||
|  | 				expectedLockState: TestLock{ | ||||||
|  | 					Locked:      false, | ||||||
|  | 					PeekedCount: 2, | ||||||
|  | 				}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: noRefresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					Lock: &TestLock{ | ||||||
|  | 						ObtainError:       errors.New("not able to obtain lock"), | ||||||
|  | 						LockedOnPeekCount: 2, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectValidated: true, | ||||||
|  | 				expectedLockState: TestLock{ | ||||||
|  | 					PeekedCount:       3, | ||||||
|  | 					LockedOnPeekCount: 2, | ||||||
|  | 					ObtainError:       errors.New("not able to obtain lock"), | ||||||
|  | 				}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: noRefresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					Lock: &TestLock{ | ||||||
|  | 						ObtainError: errors.New("not able to obtain lock"), | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectValidated: true, | ||||||
|  | 				expectedLockState: TestLock{ | ||||||
|  | 					PeekedCount: 2, | ||||||
|  | 					ObtainError: errors.New("not able to obtain lock"), | ||||||
|  | 				}, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | @ -363,42 +636,53 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
|  | 					Lock:         &TestLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:     nil, | ||||||
| 				expectRefreshed: true, | 				expectRefreshed: true, | ||||||
| 				expectValidated: true, | 				expectValidated: true, | ||||||
|  | 				expectedLockState: TestLock{ | ||||||
|  | 					Locked:      false, | ||||||
|  | 					WasObtained: true, | ||||||
|  | 					WasReleased: true, | ||||||
|  | 					PeekedCount: 1, | ||||||
|  | 				}, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ | 			Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: notImplemented, | 					RefreshToken: notImplemented, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
|  | 					Lock:         &TestLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:     nil, | ||||||
| 				expectRefreshed: true, | 				expectRefreshed: true, | ||||||
| 				expectValidated: true, | 				expectValidated: true, | ||||||
| 			}), | 				expectedLockState: TestLock{ | ||||||
| 			Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ | 					Locked:      false, | ||||||
| 				refreshPeriod: 1 * time.Minute, | 					WasObtained: true, | ||||||
| 				session: &sessionsapi.SessionState{ | 					WasReleased: true, | ||||||
| 					RefreshToken: "RefreshError", | 					PeekedCount: 1, | ||||||
| 					CreatedAt:    &createdPast, |  | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, |  | ||||||
| 				expectRefreshed: true, |  | ||||||
| 				expectValidated: true, |  | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ | 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					AccessToken:  "Invalid", | 					AccessToken:  "Invalid", | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
|  | 					Lock:         &TestLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     errors.New("session is invalid"), | 				expectedErr:     errors.New("session is invalid"), | ||||||
| 				expectRefreshed: true, | 				expectRefreshed: true, | ||||||
| 				expectValidated: true, | 				expectValidated: true, | ||||||
|  | 				expectedLockState: TestLock{ | ||||||
|  | 					Locked:      false, | ||||||
|  | 					WasObtained: true, | ||||||
|  | 					WasReleased: true, | ||||||
|  | 					PeekedCount: 1, | ||||||
|  | 				}, | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue