RefreshSessions immediately when called
This commit is contained in:
		
							parent
							
								
									5f4ac25b1e
								
							
						
					
					
						commit
						7e80e5596b
					
				|  | @ -363,8 +363,8 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt | ||||||
| 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | ||||||
| 		SessionStore:    sessionStore, | 		SessionStore:    sessionStore, | ||||||
| 		RefreshPeriod:   opts.Cookie.Refresh, | 		RefreshPeriod:   opts.Cookie.Refresh, | ||||||
| 		RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, | 		RefreshSession:  opts.GetProvider().RefreshSession, | ||||||
| 		ValidateSessionState:   opts.GetProvider().ValidateSession, | 		ValidateSession: opts.GetProvider().ValidateSession, | ||||||
| 	})) | 	})) | ||||||
| 
 | 
 | ||||||
| 	return chain | 	return chain | ||||||
|  |  | ||||||
|  | @ -24,12 +24,12 @@ type StoredSessionLoaderOptions struct { | ||||||
| 	RefreshPeriod time.Duration | 	RefreshPeriod time.Duration | ||||||
| 
 | 
 | ||||||
| 	// Provider based sesssion refreshing
 | 	// Provider based sesssion refreshing
 | ||||||
| 	RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | 	RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||||
| 
 | 
 | ||||||
| 	// Provider based session validation.
 | 	// Provider based session validation.
 | ||||||
| 	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
 | 	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
 | ||||||
| 	// refresh it, we must re-validate using this validation.
 | 	// refresh it, we must re-validate using this validation.
 | ||||||
| 	ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool | 	ValidateSession func(context.Context, *sessionsapi.SessionState) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewStoredSessionLoader creates a new storedSessionLoader which loads
 | // NewStoredSessionLoader creates a new storedSessionLoader which loads
 | ||||||
|  | @ -40,8 +40,8 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor | ||||||
| 	ss := &storedSessionLoader{ | 	ss := &storedSessionLoader{ | ||||||
| 		store:            opts.SessionStore, | 		store:            opts.SessionStore, | ||||||
| 		refreshPeriod:    opts.RefreshPeriod, | 		refreshPeriod:    opts.RefreshPeriod, | ||||||
| 		refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, | 		sessionRefresher: opts.RefreshSession, | ||||||
| 		validateSessionState:               opts.ValidateSessionState, | 		sessionValidator: opts.ValidateSession, | ||||||
| 	} | 	} | ||||||
| 	return ss.loadSession | 	return ss.loadSession | ||||||
| } | } | ||||||
|  | @ -51,8 +51,8 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor | ||||||
| type storedSessionLoader struct { | type storedSessionLoader struct { | ||||||
| 	store            sessionsapi.SessionStore | 	store            sessionsapi.SessionStore | ||||||
| 	refreshPeriod    time.Duration | 	refreshPeriod    time.Duration | ||||||
| 	refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | 	sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||||
| 	validateSessionState               func(context.Context, *sessionsapi.SessionState) bool | 	sessionValidator func(context.Context, *sessionsapi.SessionState) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // loadSession attempts to load a session as identified by the request cookies.
 | // loadSession attempts to load a session as identified by the request cookies.
 | ||||||
|  | @ -120,37 +120,38 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) | 	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) | ||||||
| 	refreshed, err := s.refreshSessionWithProvider(rw, req, session) | 	err := s.refreshSession(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !refreshed { | 	// Validate all sessions after any Redeem/Refresh operation
 | ||||||
| 		// Session wasn't refreshed, so make sure it's still valid
 |  | ||||||
| 	return s.validateSession(req.Context(), session) | 	return s.validateSession(req.Context(), session) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // refreshSession attempts to refresh the session with the provider
 | ||||||
|  | // and will save the session if it was updated.
 | ||||||
|  | func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	refreshed, err := s.sessionRefresher(req.Context(), session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("error refreshing access token: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !refreshed { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| // refreshSessionWithProvider attempts to refresh the sessinon with the provider
 | 	// If we refreshed, update the `CreatedAt` time to reset the refresh timer
 | ||||||
| // and will save the session if it was updated.
 | 	// TODO: Implement
 | ||||||
| func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { | 	// session.CreatedAtNow()
 | ||||||
| 	refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, fmt.Errorf("error refreshing access token: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if !refreshed { |  | ||||||
| 		return false, nil |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	// Because the session was refreshed, make sure to save it
 | 	// Because the session was refreshed, make sure to save it
 | ||||||
| 	err = s.store.Save(rw, req, session) | 	err = s.store.Save(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||||
| 		return false, fmt.Errorf("error saving session: %v", err) | 		return fmt.Errorf("error saving session: %v", err) | ||||||
| 	} | 	} | ||||||
| 	return true, nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // validateSession checks whether the session has expired and performs
 | // validateSession checks whether the session has expired and performs
 | ||||||
|  | @ -161,7 +162,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess | ||||||
| 		return errors.New("session is expired") | 		return errors.New("session is expired") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !s.validateSessionState(ctx, session) { | 	if !s.sessionValidator(ctx, session) { | ||||||
| 		return errors.New("session is invalid") | 		return errors.New("session is invalid") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -111,8 +111,8 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				opts := &StoredSessionLoaderOptions{ | 				opts := &StoredSessionLoaderOptions{ | ||||||
| 					SessionStore:    in.store, | 					SessionStore:    in.store, | ||||||
| 					RefreshPeriod:   in.refreshPeriod, | 					RefreshPeriod:   in.refreshPeriod, | ||||||
| 					RefreshSessionIfNeeded: in.refreshSession, | 					RefreshSession:  in.refreshSession, | ||||||
| 					ValidateSessionState:   in.validateSession, | 					ValidateSession: in.validateSession, | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				// Create the handler with a next handler that will capture the session
 | 				// Create the handler with a next handler that will capture the session
 | ||||||
|  | @ -261,7 +261,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				s := &storedSessionLoader{ | 				s := &storedSessionLoader{ | ||||||
| 					refreshPeriod: in.refreshPeriod, | 					refreshPeriod: in.refreshPeriod, | ||||||
| 					store:         &fakeSessionStore{}, | 					store:         &fakeSessionStore{}, | ||||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 					sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 						refreshed = true | 						refreshed = true | ||||||
| 						switch ss.RefreshToken { | 						switch ss.RefreshToken { | ||||||
| 						case refresh: | 						case refresh: | ||||||
|  | @ -272,7 +272,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 							return false, errors.New("error refreshing session") | 							return false, errors.New("error refreshing session") | ||||||
| 						} | 						} | ||||||
| 					}, | 					}, | ||||||
| 					validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | 					sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||||
| 						validated = true | 						validated = true | ||||||
| 						return ss.AccessToken != "Invalid" | 						return ss.AccessToken != "Invalid" | ||||||
| 					}, | 					}, | ||||||
|  | @ -364,7 +364,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	Context("refreshSessionWithProvider", func() { | 	Context("refreshSession", func() { | ||||||
| 		type refreshSessionWithProviderTableInput struct { | 		type refreshSessionWithProviderTableInput struct { | ||||||
| 			session         *sessionsapi.SessionState | 			session         *sessionsapi.SessionState | ||||||
| 			expectedErr     error | 			expectedErr     error | ||||||
|  | @ -388,7 +388,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 							return nil | 							return nil | ||||||
| 						}, | 						}, | ||||||
| 					}, | 					}, | ||||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 					sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 						switch ss.RefreshToken { | 						switch ss.RefreshToken { | ||||||
| 						case refresh: | 						case refresh: | ||||||
| 							return true, nil | 							return true, nil | ||||||
|  | @ -402,13 +402,12 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 
 | 
 | ||||||
| 				req := httptest.NewRequest("", "/", nil) | 				req := httptest.NewRequest("", "/", nil) | ||||||
| 				req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | 				req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||||
| 				refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) | 				err := s.refreshSession(nil, req, in.session) | ||||||
| 				if in.expectedErr != nil { | 				if in.expectedErr != nil { | ||||||
| 					Expect(err).To(MatchError(in.expectedErr)) | 					Expect(err).To(MatchError(in.expectedErr)) | ||||||
| 				} else { | 				} else { | ||||||
| 					Expect(err).ToNot(HaveOccurred()) | 					Expect(err).ToNot(HaveOccurred()) | ||||||
| 				} | 				} | ||||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) |  | ||||||
| 				Expect(saved).To(Equal(in.expectSaved)) | 				Expect(saved).To(Equal(in.expectSaved)) | ||||||
| 			}, | 			}, | ||||||
| 			Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ | 			Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -416,7 +415,6 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:     nil, | ||||||
| 				expectRefreshed: false, |  | ||||||
| 				expectSaved:     false, | 				expectSaved:     false, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ | 			Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -424,7 +422,6 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:     nil, | ||||||
| 				expectRefreshed: true, |  | ||||||
| 				expectSaved:     true, | 				expectSaved:     true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -434,7 +431,6 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					ExpiresOn:    &now, | 					ExpiresOn:    &now, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | ||||||
| 				expectRefreshed: false, |  | ||||||
| 				expectSaved:     false, | 				expectSaved:     false, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ | 			Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -443,7 +439,6 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					AccessToken:  "NoSave", | 					AccessToken:  "NoSave", | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     errors.New("error saving session: unable to save session"), | 				expectedErr:     errors.New("error saving session: unable to save session"), | ||||||
| 				expectRefreshed: false, |  | ||||||
| 				expectSaved:     true, | 				expectSaved:     true, | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
|  | @ -454,7 +449,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 
 | 
 | ||||||
| 		BeforeEach(func() { | 		BeforeEach(func() { | ||||||
| 			s = &storedSessionLoader{ | 			s = &storedSessionLoader{ | ||||||
| 				validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | 				sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||||
| 					return ss.AccessToken == "Valid" | 					return ss.AccessToken == "Valid" | ||||||
| 				}, | 				}, | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | @ -345,7 +345,7 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	expires := time.Now().Add(time.Duration(1) * time.Hour) | 	expires := time.Now().Add(time.Duration(1) * time.Hour) | ||||||
| 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | ||||||
| 	refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) | 	refreshNeeded, err := p.RefreshSession(context.Background(), session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.False(t, refreshNeeded) | 	assert.False(t, refreshNeeded) | ||||||
| } | } | ||||||
|  | @ -373,9 +373,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	expires := time.Now().Add(time.Duration(-1) * time.Hour) | 	expires := time.Now().Add(time.Duration(-1) * time.Hour) | ||||||
| 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | ||||||
| 	refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) | 
 | ||||||
|  | 	refreshed, err := p.RefreshSession(context.Background(), session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.True(t, refreshNeeded) | 	assert.True(t, refreshed) | ||||||
| 	assert.NotEqual(t, session, nil) | 	assert.NotEqual(t, session, nil) | ||||||
| 	assert.Equal(t, "new_some_access_token", session.AccessToken) | 	assert.Equal(t, "new_some_access_token", session.AccessToken) | ||||||
| 	assert.Equal(t, "new_some_refresh_token", session.RefreshToken) | 	assert.Equal(t, "new_some_refresh_token", session.RefreshToken) | ||||||
|  |  | ||||||
|  | @ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	return r.Email, nil | 	return r.Email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSession validates the AccessToken
 | ||||||
| func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | 	if s == nil || s.RefreshToken == "" { | ||||||
| 	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { |  | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { | func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	c := oauth2.Config{ | 	c := oauth2.Config{ | ||||||
|  | @ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("unable to update session: %v", err) | 		return fmt.Errorf("unable to update session: %v", err) | ||||||
| 	} | 	} | ||||||
| 	s.AccessToken = newSession.AccessToken | 	*s = *newSession | ||||||
| 	s.IDToken = newSession.IDToken | 
 | ||||||
| 	s.RefreshToken = newSession.RefreshToken | 	return nil | ||||||
| 	s.CreatedAt = newSession.CreatedAt |  | ||||||
| 	s.ExpiresOn = newSession.ExpiresOn |  | ||||||
| 	s.Email = newSession.Email |  | ||||||
| 	return |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type gitlabUserInfo struct { | type gitlabUserInfo struct { | ||||||
|  |  | ||||||
|  | @ -266,10 +266,9 @@ func userInGroup(service *admin.Service, group string, email string) bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | 	if s == nil || s.RefreshToken == "" { | ||||||
| 	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { |  | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	return email, nil | 	return email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSession validates the AccessToken
 | ||||||
| func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS | ||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||||
| // RefreshToken to fetch a new Access Token (and optional ID token) if required
 | func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | 	if s == nil || s.RefreshToken == "" { | ||||||
| 	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { |  | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | ||||||
| 		User:         "11223344", | 		User:         "11223344", | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) | 	refreshed, err := provider.RefreshSession(context.Background(), existingSession) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, refreshed, true) | 	assert.Equal(t, refreshed, true) | ||||||
| 	assert.Equal(t, "janedoe@example.com", existingSession.Email) | 	assert.Equal(t, "janedoe@example.com", existingSession.Email) | ||||||
|  | @ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { | ||||||
| 		Email:        "changeit", | 		Email:        "changeit", | ||||||
| 		User:         "changeit", | 		User:         "changeit", | ||||||
| 	} | 	} | ||||||
| 	refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) | 	refreshed, err := provider.RefreshSession(context.Background(), existingSession) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, refreshed, true) | 	assert.Equal(t, refreshed, true) | ||||||
| 	assert.Equal(t, defaultIDToken.Email, existingSession.Email) | 	assert.Equal(t, defaultIDToken.Email, existingSession.Email) | ||||||
|  |  | ||||||
|  | @ -126,12 +126,17 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS | ||||||
| 	return validateToken(ctx, p, s.AccessToken, nil) | 	return validateToken(ctx, p, s.AccessToken, nil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded should refresh the user's session if required and
 | // RefreshSession refreshes the user's session
 | ||||||
| // do nothing if a refresh is not required
 | func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { | 	if s == nil { | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Pretend `RefreshSession` occured so `ValidateSession` isn't called
 | ||||||
|  | 	// on every request after any potential set refresh period elapses.
 | ||||||
|  | 	return true, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // CreateSessionFromToken converts Bearer IDTokens into sessions
 | // CreateSessionFromToken converts Bearer IDTokens into sessions
 | ||||||
| func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | ||||||
| 	if p.Verifier != nil { | 	if p.Verifier != nil { | ||||||
|  |  | ||||||
|  | @ -15,7 +15,7 @@ func TestRefresh(t *testing.T) { | ||||||
| 	p := &ProviderData{} | 	p := &ProviderData{} | ||||||
| 
 | 
 | ||||||
| 	expires := time.Now().Add(time.Duration(-11) * time.Minute) | 	expires := time.Now().Add(time.Duration(-11) * time.Minute) | ||||||
| 	refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ | 	refreshed, err := p.RefreshSession(context.Background(), &sessions.SessionState{ | ||||||
| 		ExpiresOn: &expires, | 		ExpiresOn: &expires, | ||||||
| 	}) | 	}) | ||||||
| 	assert.Equal(t, false, refreshed) | 	assert.Equal(t, false, refreshed) | ||||||
|  |  | ||||||
|  | @ -9,14 +9,14 @@ import ( | ||||||
| // Provider represents an upstream identity provider implementation
 | // Provider represents an upstream identity provider implementation
 | ||||||
| type Provider interface { | type Provider interface { | ||||||
| 	Data() *ProviderData | 	Data() *ProviderData | ||||||
|  | 	GetLoginURL(redirectURI, finalRedirect string, nonce string) string | ||||||
|  | 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) | ||||||
| 	// Deprecated: Migrate to EnrichSession
 | 	// Deprecated: Migrate to EnrichSession
 | ||||||
| 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) | 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) | ||||||
| 	GetLoginURL(redirectURI, state, nonce string) string |  | ||||||
| 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) |  | ||||||
| 	EnrichSession(ctx context.Context, s *sessions.SessionState) error | 	EnrichSession(ctx context.Context, s *sessions.SessionState) error | ||||||
| 	Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) | 	Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||||
| 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool | 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool | ||||||
| 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | 	RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||||
| 	CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) | 	CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue