Refactor StoredSessionHandler
This commit is contained in:
		
							parent
							
								
									518e619289
								
							
						
					
					
						commit
						86ba2f41ce
					
				|  | @ -109,6 +109,12 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h | ||||||
| 		return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) | 		return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | ||||||
|  | 	err = s.validateSession(req.Context(), session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return session, nil | 	return session, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -121,36 +127,35 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var wasLocked bool | 	wasLocked, err := s.waitForPossibleSessionLock(session, req) | ||||||
| 	var err error |  | ||||||
| 	var isLocked bool |  | ||||||
| 	for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) { |  | ||||||
| 		wasLocked = true |  | ||||||
| 		// delay next peek lock
 |  | ||||||
| 		time.Sleep(SessionLockPeekDelay) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// If session was locked fetch current state
 | 	// If session was locked, fetch current state, because
 | ||||||
|  | 	// it should be updated after lock is released.
 | ||||||
| 	if wasLocked { | 	if wasLocked { | ||||||
| 		var sessionStored *sessionsapi.SessionState | 		err = s.updateSessionFromStore(req, session) | ||||||
| 		sessionStored, err = s.store.Load(req) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			logger.Errorf("Unable to load updated session from store: %v", err) | ||||||
| 		} | 		} | ||||||
| 
 | 		return err | ||||||
| 		if session == nil || sessionStored == nil { |  | ||||||
| 			return nil |  | ||||||
| 		} |  | ||||||
| 		*session = *sessionStored |  | ||||||
| 
 |  | ||||||
| 		return nil |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = session.ObtainLock(req.Context(), SessionLockExpireTime) | 	logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) | ||||||
|  | 	err = s.refreshSession(rw, req, session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// If a preemptive refresh fails, we still keep the session
 | ||||||
|  | 		// if validateSession succeeds.
 | ||||||
|  | 		logger.Errorf("Unable to refresh session: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // refreshSession attempts to refresh the session with the provider
 | ||||||
|  | // and will save the session if it was updated.
 | ||||||
|  | func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	err := session.ObtainLock(req.Context(), SessionLockExpireTime) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("unable to obtain lock (skipping refresh): %v", err) | 		logger.Errorf("unable to obtain lock (skipping refresh): %v", err) | ||||||
| 		return nil | 		return nil | ||||||
|  | @ -162,21 +167,6 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) |  | ||||||
| 	err = s.refreshSession(rw, req, session) |  | ||||||
| 	if err != nil { |  | ||||||
| 		// If a preemptive refresh fails, we still keep the session
 |  | ||||||
| 		// if validateSession succeeds.
 |  | ||||||
| 		logger.Errorf("Unable to refresh session: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 |  | ||||||
| 	return s.validateSession(req.Context(), session) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // refreshSession attempts to refresh the session with the provider
 |  | ||||||
| // and will save the session if it was updated.
 |  | ||||||
| func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { |  | ||||||
| 	refreshed, err := s.sessionRefresher(req.Context(), session) | 	refreshed, err := s.sessionRefresher(req.Context(), session) | ||||||
| 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||||
| 		return fmt.Errorf("error refreshing tokens: %v", err) | 		return fmt.Errorf("error refreshing tokens: %v", err) | ||||||
|  | @ -205,11 +195,42 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R | ||||||
| 	err = s.store.Save(rw, req, session) | 	err = s.store.Save(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||||
| 		return fmt.Errorf("error saving session: %v", err) | 		err = fmt.Errorf("error saving session: %v", err) | ||||||
| 	} | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *storedSessionLoader) updateSessionFromStore(req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	sessionStored, err := s.store.Load(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if session == nil || sessionStored == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	*session = *sessionStored | ||||||
|  | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (s *storedSessionLoader) waitForPossibleSessionLock(session *sessionsapi.SessionState, req *http.Request) (bool, error) { | ||||||
|  | 	var wasLocked bool | ||||||
|  | 	var err error | ||||||
|  | 	var isLocked bool | ||||||
|  | 	for isLocked, err = session.PeekLock(req.Context()); isLocked; isLocked, err = session.PeekLock(req.Context()) { | ||||||
|  | 		wasLocked = true | ||||||
|  | 		// delay next peek lock
 | ||||||
|  | 		time.Sleep(SessionLockPeekDelay) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return wasLocked, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // validateSession checks whether the session has expired and performs
 | // validateSession checks whether the session has expired and performs
 | ||||||
| // provider validation on the session.
 | // provider validation on the session.
 | ||||||
| // An error implies the session is not longer valid.
 | // An error implies the session is not longer valid.
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue