Implement refresh relying on obtaining lock
This commit is contained in:
		
							parent
							
								
									e2c7ff6ddd
								
							
						
					
					
						commit
						54d42c5829
					
				|  | @ -15,8 +15,20 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	SessionLockExpireTime = 5 * time.Second | 	// When attempting to obtain the lock, if it's not done before this timeout
 | ||||||
| 	SessionLockPeekDelay  = 50 * time.Millisecond | 	// then exit and fail the refresh attempt.
 | ||||||
|  | 	// TODO: This should probably be configurable by the end user.
 | ||||||
|  | 	sessionRefreshObtainTimeout = 5 * time.Second | ||||||
|  | 
 | ||||||
|  | 	// Maximum time allowed for a session refresh attempt.
 | ||||||
|  | 	// If the refresh request isn't finished within this time, the lock will be
 | ||||||
|  | 	// released.
 | ||||||
|  | 	// TODO: This should probably be configurable by the end user.
 | ||||||
|  | 	sessionRefreshLockDuration = 2 * time.Second | ||||||
|  | 
 | ||||||
|  | 	// How long to wait after failing to obtain the lock before trying again.
 | ||||||
|  | 	// TODO: This should probably be configurable by the end user.
 | ||||||
|  | 	sessionRefreshRetryPeriod = 10 * time.Millisecond | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // StoredSessionLoaderOptions contains all of the requirements to construct
 | // StoredSessionLoaderOptions contains all of the requirements to construct
 | ||||||
|  | @ -113,47 +125,86 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h | ||||||
| // is older than the refresh period.
 | // is older than the refresh period.
 | ||||||
| // Success or fail, we will then validate the session.
 | // Success or fail, we will then validate the session.
 | ||||||
| func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
| 	if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { | 	if !needsRefresh(s.refreshPeriod, session) { | ||||||
| 		// Refresh is disabled or the session is not old enough, do nothing
 | 		// Refresh is disabled or the session is not old enough, do nothing
 | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	wasRefreshed, err := s.checkForConcurrentRefresh(session, req) | 	var lockObtained bool | ||||||
| 	if err != nil { | 	ctx, cancel := context.WithTimeout(context.Background(), sessionRefreshObtainTimeout) | ||||||
| 		return err | 	defer cancel() | ||||||
|  | 
 | ||||||
|  | 	for !lockObtained { | ||||||
|  | 		select { | ||||||
|  | 		case <-ctx.Done(): | ||||||
|  | 			return errors.New("timeout obtaining session lock") | ||||||
|  | 		default: | ||||||
|  | 			err := session.ObtainLock(req.Context(), sessionRefreshLockDuration) | ||||||
|  | 			if err != nil && !errors.Is(err, sessionsapi.ErrLockNotObtained) { | ||||||
|  | 				return fmt.Errorf("error occurred while trying to obtain lock: %v", err) | ||||||
|  | 			} else if errors.Is(err, sessionsapi.ErrLockNotObtained) { | ||||||
|  | 				time.Sleep(sessionRefreshRetryPeriod) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			// No error means we obtained the lock
 | ||||||
|  | 			lockObtained = true | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// If session was already refreshed via a concurrent request locked skip refreshing,
 | 	// The rest of this function is carried out under lock, but we must release it
 | ||||||
| 	// because the refreshed session is already loaded from storage
 | 	// wherever we exit from this function.
 | ||||||
| 	if !wasRefreshed { | 	defer func() { | ||||||
| 		logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) | 		if session == nil { | ||||||
| 		err = s.refreshSession(rw, req, session) | 			return | ||||||
| 		if err != nil { |  | ||||||
| 			// If a preemptive refresh fails, we still keep the session
 |  | ||||||
| 			// if validateSession succeeds.
 |  | ||||||
| 			logger.Errorf("Unable to refresh session: %v", err) |  | ||||||
| 		} | 		} | ||||||
|  | 		if err := session.ReleaseLock(req.Context()); err != nil { | ||||||
|  | 			logger.Errorf("unable to release lock: %v", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// Reload the session in case it was changed underneath us.
 | ||||||
|  | 	freshSession, err := s.store.Load(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("could not load session: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if freshSession == nil { | ||||||
|  | 		return errors.New("session no longer exists, it may have been removed by another request") | ||||||
|  | 	} | ||||||
|  | 	// Restore the state of the fresh session into the original pointer.
 | ||||||
|  | 	// This is important so that changes are passed up the to the parent scope.
 | ||||||
|  | 	lock := session.Lock | ||||||
|  | 	*session = *freshSession | ||||||
|  | 
 | ||||||
|  | 	// Ensure we maintain the session lock after we have refreshed the session.
 | ||||||
|  | 	// Loading from the session store creates a new lock in the session.
 | ||||||
|  | 	session.Lock = lock | ||||||
|  | 
 | ||||||
|  | 	if !needsRefresh(s.refreshPeriod, session) { | ||||||
|  | 		// The session must have already been refreshed while we were waiting to
 | ||||||
|  | 		// obtain the lock.
 | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// We are holding the lock and the session needs a refresh
 | ||||||
|  | 	logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) | ||||||
|  | 	if err := s.refreshSession(rw, req, session); err != nil { | ||||||
|  | 		// If a preemptive refresh fails, we still keep the session
 | ||||||
|  | 		// if validateSession succeeds.
 | ||||||
|  | 		logger.Errorf("Unable to refresh session: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | 	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | ||||||
| 	return s.validateSession(req.Context(), session) | 	return s.validateSession(req.Context(), session) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // needsRefresh determines whether we should attempt to refresh a session or not.
 | ||||||
|  | func needsRefresh(refreshPeriod time.Duration, session *sessionsapi.SessionState) bool { | ||||||
|  | 	return refreshPeriod > time.Duration(0) && session.Age() > refreshPeriod | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // refreshSession attempts to refresh the session with the provider
 | // refreshSession attempts to refresh the session with the provider
 | ||||||
| // and will save the session if it was updated.
 | // and will save the session if it was updated.
 | ||||||
| func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
| 	err := session.ObtainLock(req.Context(), SessionLockExpireTime) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Errorf("Unable to obtain lock: %v", err) |  | ||||||
| 		return s.handleObtainLockError(req, session) |  | ||||||
| 	} |  | ||||||
| 	defer func() { |  | ||||||
| 		err = session.ReleaseLock(req.Context()) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Errorf("unable to release lock: %v", err) |  | ||||||
| 		} |  | ||||||
| 	}() |  | ||||||
| 
 |  | ||||||
| 	refreshed, err := s.sessionRefresher(req.Context(), session) | 	refreshed, err := s.sessionRefresher(req.Context(), session) | ||||||
| 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||||
| 		return fmt.Errorf("error refreshing tokens: %v", err) | 		return fmt.Errorf("error refreshing tokens: %v", err) | ||||||
|  | @ -182,75 +233,11 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R | ||||||
| 	err = s.store.Save(rw, req, session) | 	err = s.store.Save(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||||
| 		err = fmt.Errorf("error saving session: %v", err) | 		return fmt.Errorf("error saving session: %v", err) | ||||||
| 	} |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *storedSessionLoader) handleObtainLockError(req *http.Request, session *sessionsapi.SessionState) error { |  | ||||||
| 	wasRefreshed, err := s.checkForConcurrentRefresh(session, req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Errorf("Unable to wait for obtained lock: %v", err) |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	if !wasRefreshed { |  | ||||||
| 		return errors.New("unable to obtain lock and session was also not refreshed via concurrent request") |  | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { |  | ||||||
| 	sessionStored, err := s.store.Load(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("unable to load updated session from store: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if sessionStored == nil { |  | ||||||
| 		return fmt.Errorf("no session available to udpate from store") |  | ||||||
| 	} |  | ||||||
| 	*session = *sessionStored |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { |  | ||||||
| 	var wasLocked bool |  | ||||||
| 	isLocked, err := session.PeekLock(req.Context()) |  | ||||||
| 	for isLocked { |  | ||||||
| 		wasLocked = true |  | ||||||
| 		// delay next peek lock
 |  | ||||||
| 		time.Sleep(SessionLockPeekDelay) |  | ||||||
| 		isLocked, err = session.PeekLock(req.Context()) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return wasLocked, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // checkForConcurrentRefresh returns true if the session is already refreshed via a concurrent request.
 |  | ||||||
| func (s *storedSessionLoader) checkForConcurrentRefresh(session *sessionsapi.SessionState, req *http.Request) (bool, error) { |  | ||||||
| 	wasLocked, err := s.waitForPossibleSessionLock(session, req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	refreshed := false |  | ||||||
| 	if wasLocked { |  | ||||||
| 		logger.Printf("Update session from store instead of refreshing") |  | ||||||
| 		err = s.updateSessionFromStore(req, session) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Errorf("Unable to update session from store: %v", err) |  | ||||||
| 			return false, err |  | ||||||
| 		} |  | ||||||
| 		refreshed = true |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return refreshed, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // validateSession checks whether the session has expired and performs
 | // validateSession checks whether the session has expired and performs
 | ||||||
| // provider validation on the session.
 | // provider validation on the session.
 | ||||||
| // An error implies the session is not longer valid.
 | // An error implies the session is not longer valid.
 | ||||||
|  |  | ||||||
|  | @ -18,97 +18,67 @@ import ( | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type TestLock struct { | type testLock struct { | ||||||
| 	Locked            bool | 	locked          bool | ||||||
| 	WasObtained       bool | 	obtainOnAttempt int | ||||||
| 	WasRefreshed      bool | 	obtainAttempts  int | ||||||
| 	WasReleased       bool | 	obtainError     error | ||||||
| 	PeekedCount       int |  | ||||||
| 	LockedOnPeekCount int |  | ||||||
| 	ObtainError       error |  | ||||||
| 	PeekError         error |  | ||||||
| 	RefreshError      error |  | ||||||
| 	ReleaseError      error |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (l *TestLock) Obtain(_ context.Context, _ time.Duration) error { | func (l *testLock) Obtain(_ context.Context, _ time.Duration) error { | ||||||
| 	if l.ObtainError != nil { | 	l.obtainAttempts++ | ||||||
| 		return l.ObtainError | 	if l.obtainAttempts < l.obtainOnAttempt { | ||||||
| 	} |  | ||||||
| 	l.Locked = true |  | ||||||
| 	l.WasObtained = true |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (l *TestLock) Peek(_ context.Context) (bool, error) { |  | ||||||
| 	if l.PeekError != nil { |  | ||||||
| 		return false, l.PeekError |  | ||||||
| 	} |  | ||||||
| 	locked := l.Locked |  | ||||||
| 	l.Locked = false |  | ||||||
| 	l.PeekedCount++ |  | ||||||
| 	// mainly used to test case when peek initially returns false,
 |  | ||||||
| 	// but when trying to obtain lock, it returns true.
 |  | ||||||
| 	if l.LockedOnPeekCount == l.PeekedCount { |  | ||||||
| 		return true, nil |  | ||||||
| 	} |  | ||||||
| 	return locked, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (l *TestLock) Refresh(_ context.Context, _ time.Duration) error { |  | ||||||
| 	if l.RefreshError != nil { |  | ||||||
| 		return l.ReleaseError |  | ||||||
| 	} |  | ||||||
| 	l.WasRefreshed = true |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (l *TestLock) Release(_ context.Context) error { |  | ||||||
| 	if l.ReleaseError != nil { |  | ||||||
| 		return l.ReleaseError |  | ||||||
| 	} |  | ||||||
| 	l.Locked = false |  | ||||||
| 	l.WasReleased = true |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type LockConc struct { |  | ||||||
| 	mu          sync.Mutex |  | ||||||
| 	lock        bool |  | ||||||
| 	disablePeek bool |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { |  | ||||||
| 	l.mu.Lock() |  | ||||||
| 	if l.lock { |  | ||||||
| 		l.mu.Unlock() |  | ||||||
| 		return sessionsapi.ErrLockNotObtained | 		return sessionsapi.ErrLockNotObtained | ||||||
| 	} | 	} | ||||||
| 	l.lock = true | 	if l.obtainError != nil { | ||||||
| 	l.mu.Unlock() | 		return l.obtainError | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (l *LockConc) Peek(_ context.Context) (bool, error) { |  | ||||||
| 	var response bool |  | ||||||
| 	l.mu.Lock() |  | ||||||
| 	if l.disablePeek { |  | ||||||
| 		response = false |  | ||||||
| 	} else { |  | ||||||
| 		response = l.lock |  | ||||||
| 	} | 	} | ||||||
| 	l.mu.Unlock() | 	l.locked = true | ||||||
| 	return response, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (l *LockConc) Release(_ context.Context) error { | func (l *testLock) Peek(_ context.Context) (bool, error) { | ||||||
|  | 	return l.locked, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *testLock) Refresh(_ context.Context, _ time.Duration) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *testLock) Release(_ context.Context) error { | ||||||
|  | 	l.locked = false | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type testLockConcurrent struct { | ||||||
|  | 	mu     sync.RWMutex | ||||||
|  | 	locked bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *testLockConcurrent) Obtain(_ context.Context, _ time.Duration) error { | ||||||
| 	l.mu.Lock() | 	l.mu.Lock() | ||||||
| 	l.lock = false | 	defer l.mu.Unlock() | ||||||
| 	l.mu.Unlock() | 	if l.locked { | ||||||
|  | 		return sessionsapi.ErrLockNotObtained | ||||||
|  | 	} | ||||||
|  | 	l.locked = true | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *testLockConcurrent) Peek(_ context.Context) (bool, error) { | ||||||
|  | 	l.mu.RLock() | ||||||
|  | 	defer l.mu.RUnlock() | ||||||
|  | 	return l.locked, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *testLockConcurrent) Refresh(_ context.Context, _ time.Duration) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *testLockConcurrent) Release(_ context.Context) error { | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	defer l.mu.Unlock() | ||||||
|  | 	l.locked = false | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -374,22 +344,29 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 
 | 
 | ||||||
| 		DescribeTable("when serving concurrent requests", | 		DescribeTable("when serving concurrent requests", | ||||||
| 			func(in storedSessionLoaderConcurrentTableInput) { | 			func(in storedSessionLoaderConcurrentTableInput) { | ||||||
| 				lockConc := &LockConc{} | 				lockConc := &testLockConcurrent{} | ||||||
|  | 
 | ||||||
|  | 				lock := &sync.RWMutex{} | ||||||
|  | 				existingSession := *in.existingSession // deep copy existingSession state
 | ||||||
|  | 				existingSession.Lock = lockConc | ||||||
|  | 				store := &fakeSessionStore{ | ||||||
|  | 					LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 						lock.RLock() | ||||||
|  | 						defer lock.RUnlock() | ||||||
|  | 						session := existingSession | ||||||
|  | 						return &session, nil | ||||||
|  | 					}, | ||||||
|  | 					SaveFunc: func(_ http.ResponseWriter, _ *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 						lock.Lock() | ||||||
|  | 						defer lock.Unlock() | ||||||
|  | 						existingSession = *session | ||||||
|  | 						return nil | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
| 
 | 
 | ||||||
| 				refreshedChan := make(chan bool, in.numConcReqs) | 				refreshedChan := make(chan bool, in.numConcReqs) | ||||||
| 				for i := 0; i < in.numConcReqs; i++ { | 				for i := 0; i < in.numConcReqs; i++ { | ||||||
| 					go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { | 					go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { | ||||||
| 						existingSession := *in.existingSession // deep copy existingSession state
 |  | ||||||
| 						existingSession.Lock = lockConc |  | ||||||
| 						store := &fakeSessionStore{ |  | ||||||
| 							LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { |  | ||||||
| 								return &existingSession, nil |  | ||||||
| 							}, |  | ||||||
| 							SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { |  | ||||||
| 								return nil |  | ||||||
| 							}, |  | ||||||
| 						} |  | ||||||
| 
 |  | ||||||
| 						scope := &middlewareapi.RequestScope{ | 						scope := &middlewareapi.RequestScope{ | ||||||
| 							Session: nil, | 							Session: nil, | ||||||
| 						} | 						} | ||||||
|  | @ -461,13 +438,13 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 
 | 
 | ||||||
| 	Context("refreshSessionIfNeeded", func() { | 	Context("refreshSessionIfNeeded", func() { | ||||||
| 		type refreshSessionIfNeededTableInput struct { | 		type refreshSessionIfNeededTableInput struct { | ||||||
| 			refreshPeriod     time.Duration | 			refreshPeriod            time.Duration | ||||||
| 			sessionStored     bool | 			session                  *sessionsapi.SessionState | ||||||
| 			session           *sessionsapi.SessionState | 			concurrentSessionRefresh bool | ||||||
| 			expectedErr       error | 			expectedErr              error | ||||||
| 			expectRefreshed   bool | 			expectRefreshed          bool | ||||||
| 			expectValidated   bool | 			expectValidated          bool | ||||||
| 			expectedLockState TestLock | 			expectedLockObtained     bool | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		createdPast := time.Now().Add(-5 * time.Minute) | 		createdPast := time.Now().Add(-5 * time.Minute) | ||||||
|  | @ -478,13 +455,23 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				refreshed := false | 				refreshed := false | ||||||
| 				validated := false | 				validated := false | ||||||
| 
 | 
 | ||||||
| 				store := &fakeSessionStore{} | 				session := &sessionsapi.SessionState{} | ||||||
| 				if in.sessionStored { | 				*session = *in.session | ||||||
| 					store = &fakeSessionStore{ | 				if in.concurrentSessionRefresh { | ||||||
| 						LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | 					// Update the session that Load returns.
 | ||||||
| 							return in.session, nil | 					// This simulates a concurrent refresh in the background.
 | ||||||
| 						}, | 					session.CreatedAt = &createdFuture | ||||||
| 					} | 				} | ||||||
|  | 				store := &fakeSessionStore{ | ||||||
|  | 					LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 						// Loading the session from the provider creates a new lock
 | ||||||
|  | 						session.Lock = &testLock{} | ||||||
|  | 						return session, nil | ||||||
|  | 					}, | ||||||
|  | 					SaveFunc: func(_ http.ResponseWriter, _ *http.Request, s *sessionsapi.SessionState) error { | ||||||
|  | 						*session = *s | ||||||
|  | 						return nil | ||||||
|  | 					}, | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				s := &storedSessionLoader{ | 				s := &storedSessionLoader{ | ||||||
|  | @ -518,117 +505,90 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				} | 				} | ||||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) | 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||||
| 				Expect(validated).To(Equal(in.expectValidated)) | 				Expect(validated).To(Equal(in.expectValidated)) | ||||||
| 				testLock, ok := in.session.Lock.(*TestLock) | 				testLock, ok := in.session.Lock.(*testLock) | ||||||
| 				Expect(ok).To(Equal(true)) | 				Expect(ok).To(Equal(true)) | ||||||
| 
 | 
 | ||||||
| 				Expect(testLock).To(Equal(&in.expectedLockState)) | 				if in.expectedLockObtained { | ||||||
|  | 					Expect(testLock.obtainAttempts).Should(BeNumerically(">", 0), "Expected at least one attempt at obtaining the session lock") | ||||||
|  | 				} | ||||||
|  | 				Expect(testLock.locked).To(BeFalse(), "Expected lock should always be released") | ||||||
| 			}, | 			}, | ||||||
| 			Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ | 			Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: time.Duration(0), | 				refreshPeriod: time.Duration(0), | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdFuture, | 					CreatedAt:    &createdFuture, | ||||||
| 					Lock:         &TestLock{}, | 					Lock:         &testLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:       nil, | 				expectedErr:          nil, | ||||||
| 				expectRefreshed:   false, | 				expectRefreshed:      false, | ||||||
| 				expectValidated:   false, | 				expectValidated:      false, | ||||||
| 				expectedLockState: TestLock{}, | 				expectedLockObtained: false, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ | 			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: time.Duration(0), | 				refreshPeriod: time.Duration(0), | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					Lock:         &TestLock{}, | 					Lock:         &testLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:       nil, | 				expectedErr:          nil, | ||||||
| 				expectRefreshed:   false, | 				expectRefreshed:      false, | ||||||
| 				expectValidated:   false, | 				expectValidated:      false, | ||||||
| 				expectedLockState: TestLock{}, | 				expectedLockObtained: false, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ | 			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdFuture, | 					CreatedAt:    &createdFuture, | ||||||
| 					Lock:         &TestLock{}, | 					Lock:         &testLock{}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:       nil, | 				expectedErr:          nil, | ||||||
| 				expectRefreshed:   false, | 				expectRefreshed:      false, | ||||||
| 				expectValidated:   false, | 				expectValidated:      false, | ||||||
| 				expectedLockState: TestLock{}, | 				expectedLockObtained: false, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ | 			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					Lock:         &TestLock{}, | 					Lock:         &testLock{}, | ||||||
| 				}, |  | ||||||
| 				expectedErr:     nil, |  | ||||||
| 				expectRefreshed: true, |  | ||||||
| 				expectValidated: true, |  | ||||||
| 				expectedLockState: TestLock{ |  | ||||||
| 					Locked:      false, |  | ||||||
| 					WasObtained: true, |  | ||||||
| 					WasReleased: true, |  | ||||||
| 					PeekedCount: 1, |  | ||||||
| 				}, |  | ||||||
| 			}), |  | ||||||
| 			Entry("when the session is locked and instead loaded from storage", refreshSessionIfNeededTableInput{ |  | ||||||
| 				refreshPeriod: 1 * time.Minute, |  | ||||||
| 				session: &sessionsapi.SessionState{ |  | ||||||
| 					RefreshToken: noRefresh, |  | ||||||
| 					CreatedAt:    &createdPast, |  | ||||||
| 					Lock: &TestLock{ |  | ||||||
| 						Locked: true, |  | ||||||
| 					}, |  | ||||||
| 				}, |  | ||||||
| 				sessionStored:   true, |  | ||||||
| 				expectedErr:     nil, |  | ||||||
| 				expectRefreshed: false, |  | ||||||
| 				expectValidated: true, |  | ||||||
| 				expectedLockState: TestLock{ |  | ||||||
| 					Locked:      false, |  | ||||||
| 					PeekedCount: 2, |  | ||||||
| 				}, | 				}, | ||||||
|  | 				expectedErr:          nil, | ||||||
|  | 				expectRefreshed:      true, | ||||||
|  | 				expectValidated:      true, | ||||||
|  | 				expectedLockObtained: true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ | 			Entry("when obtaining lock failed, but concurrent request refreshed", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					Lock: &TestLock{ | 					Lock: &testLock{ | ||||||
| 						ObtainError:       errors.New("not able to obtain lock"), | 						obtainOnAttempt: 4, | ||||||
| 						LockedOnPeekCount: 2, |  | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				concurrentSessionRefresh: true, | ||||||
| 				expectRefreshed: false, | 				expectedErr:              nil, | ||||||
| 				expectValidated: true, | 				expectRefreshed:          false, | ||||||
| 				expectedLockState: TestLock{ | 				expectValidated:          false, | ||||||
| 					PeekedCount:       3, | 				expectedLockObtained:     true, | ||||||
| 					LockedOnPeekCount: 2, |  | ||||||
| 					ObtainError:       errors.New("not able to obtain lock"), |  | ||||||
| 				}, |  | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when obtaining lock failed", refreshSessionIfNeededTableInput{ | 			Entry("when obtaining lock failed with a valid session", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					Lock: &TestLock{ | 					Lock: &testLock{ | ||||||
| 						ObtainError: errors.New("not able to obtain lock"), | 						obtainError: sessionsapi.ErrLockNotObtained, | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:          errors.New("timeout obtaining session lock"), | ||||||
| 				expectRefreshed: false, | 				expectRefreshed:      false, | ||||||
| 				expectValidated: true, | 				expectValidated:      false, | ||||||
| 				expectedLockState: TestLock{ | 				expectedLockObtained: true, | ||||||
| 					PeekedCount: 2, |  | ||||||
| 					ObtainError: errors.New("not able to obtain lock"), |  | ||||||
| 				}, |  | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | @ -636,34 +596,24 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
| 					Lock:         &TestLock{}, | 					Lock:         &testLock{}, | ||||||
| 				}, |  | ||||||
| 				expectedErr:     nil, |  | ||||||
| 				expectRefreshed: true, |  | ||||||
| 				expectValidated: true, |  | ||||||
| 				expectedLockState: TestLock{ |  | ||||||
| 					Locked:      false, |  | ||||||
| 					WasObtained: true, |  | ||||||
| 					WasReleased: true, |  | ||||||
| 					PeekedCount: 1, |  | ||||||
| 				}, | 				}, | ||||||
|  | 				expectedErr:          nil, | ||||||
|  | 				expectRefreshed:      true, | ||||||
|  | 				expectValidated:      true, | ||||||
|  | 				expectedLockObtained: true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ | 			Entry("when the provider doesn't implement refresh", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: notImplemented, | 					RefreshToken: notImplemented, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					Lock:         &TestLock{}, | 					Lock:         &testLock{}, | ||||||
| 				}, |  | ||||||
| 				expectedErr:     nil, |  | ||||||
| 				expectRefreshed: true, |  | ||||||
| 				expectValidated: true, |  | ||||||
| 				expectedLockState: TestLock{ |  | ||||||
| 					Locked:      false, |  | ||||||
| 					WasObtained: true, |  | ||||||
| 					WasReleased: true, |  | ||||||
| 					PeekedCount: 1, |  | ||||||
| 				}, | 				}, | ||||||
|  | 				expectedErr:          nil, | ||||||
|  | 				expectRefreshed:      true, | ||||||
|  | 				expectValidated:      true, | ||||||
|  | 				expectedLockObtained: true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | @ -672,17 +622,12 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
| 					Lock:         &TestLock{}, | 					Lock:         &testLock{}, | ||||||
| 				}, |  | ||||||
| 				expectedErr:     errors.New("session is invalid"), |  | ||||||
| 				expectRefreshed: true, |  | ||||||
| 				expectValidated: true, |  | ||||||
| 				expectedLockState: TestLock{ |  | ||||||
| 					Locked:      false, |  | ||||||
| 					WasObtained: true, |  | ||||||
| 					WasReleased: true, |  | ||||||
| 					PeekedCount: 1, |  | ||||||
| 				}, | 				}, | ||||||
|  | 				expectedErr:          errors.New("session is invalid"), | ||||||
|  | 				expectRefreshed:      true, | ||||||
|  | 				expectValidated:      true, | ||||||
|  | 				expectedLockObtained: true, | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue