Use `ErrNotImplemented` in default refresh implementation
This commit is contained in:
		
							parent
							
								
									baf6cf3816
								
							
						
					
					
						commit
						ff914d7e17
					
				|  | @ -4,10 +4,16 @@ | |||
| 
 | ||||
| ## Important Notes | ||||
| 
 | ||||
| - [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session | ||||
|   deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade | ||||
|   to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret` | ||||
|   value and force all sessions to reauthenticate. | ||||
| 
 | ||||
| ## Breaking Changes | ||||
| 
 | ||||
| ## Changes since v7.1.3 | ||||
| 
 | ||||
| - [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves) | ||||
| - [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed) | ||||
| - [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed) | ||||
| - [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi) | ||||
|  |  | |||
|  | @ -11,19 +11,20 @@ import ( | |||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/providers" | ||||
| ) | ||||
| 
 | ||||
| // StoredSessionLoaderOptions cotnains all of the requirements to construct
 | ||||
| // StoredSessionLoaderOptions contains all of the requirements to construct
 | ||||
| // a stored session loader.
 | ||||
| // All options must be provided.
 | ||||
| type StoredSessionLoaderOptions struct { | ||||
| 	// Session storage basckend
 | ||||
| 	// Session storage backend
 | ||||
| 	SessionStore sessionsapi.SessionStore | ||||
| 
 | ||||
| 	// How often should sessions be refreshed
 | ||||
| 	RefreshPeriod time.Duration | ||||
| 
 | ||||
| 	// Provider based sesssion refreshing
 | ||||
| 	// Provider based session refreshing
 | ||||
| 	RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||
| 
 | ||||
| 	// Provider based session validation.
 | ||||
|  | @ -115,7 +116,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req | |||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) | ||||
| 	logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) | ||||
| 	err := s.refreshSession(rw, req, session) | ||||
| 	if err != nil { | ||||
| 		// If a preemptive refresh fails, we still keep the session
 | ||||
|  | @ -131,21 +132,27 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req | |||
| // and will save the session if it was updated.
 | ||||
| func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||
| 	refreshed, err := s.sessionRefresher(req.Context(), session) | ||||
| 	if err != nil { | ||||
| 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||
| 		return fmt.Errorf("error refreshing tokens: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// HACK:
 | ||||
| 	// Providers that don't implement `RefreshSession` use the default
 | ||||
| 	// implementation which returns `ErrNotImplemented`.
 | ||||
| 	// Pretend it refreshed to reset the refresh timer so that `ValidateSession`
 | ||||
| 	// isn't triggered every subsequent request and is only called once during
 | ||||
| 	// this request.
 | ||||
| 	if errors.Is(err, providers.ErrNotImplemented) { | ||||
| 		refreshed = true | ||||
| 	} | ||||
| 
 | ||||
| 	// Session not refreshed, nothing to persist.
 | ||||
| 	if !refreshed { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	// If we refreshed, update the `CreatedAt` time to reset the refresh timer
 | ||||
| 	//
 | ||||
| 	// HACK:
 | ||||
| 	// Providers that don't implement `RefreshSession` use the default
 | ||||
| 	// implementation. It always returns `refreshed == true`, so the
 | ||||
| 	// `session.CreatedAt` is updated and doesn't trigger `ValidateSession`
 | ||||
| 	// every subsequent request.
 | ||||
| 	// (In case underlying provider implementations forget)
 | ||||
| 	session.CreatedAtNow() | ||||
| 
 | ||||
| 	// Because the session was refreshed, make sure to save it
 | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ import ( | |||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/providers" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -18,8 +19,9 @@ import ( | |||
| 
 | ||||
| var _ = Describe("Stored Session Suite", func() { | ||||
| 	const ( | ||||
| 		refresh   = "Refresh" | ||||
| 		noRefresh = "NoRefresh" | ||||
| 		refresh        = "Refresh" | ||||
| 		noRefresh      = "NoRefresh" | ||||
| 		notImplemented = "NotImplemented" | ||||
| 	) | ||||
| 
 | ||||
| 	var ctx = context.Background() | ||||
|  | @ -293,6 +295,8 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 							return true, nil | ||||
| 						case noRefresh: | ||||
| 							return false, nil | ||||
| 						case notImplemented: | ||||
| 							return false, providers.ErrNotImplemented | ||||
| 						default: | ||||
| 							return false, errors.New("error refreshing session") | ||||
| 						} | ||||
|  | @ -364,6 +368,16 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 				expectRefreshed: true, | ||||
| 				expectValidated: true, | ||||
| 			}), | ||||
| 			Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: notImplemented, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: true, | ||||
| 			}), | ||||
| 			Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
|  | @ -418,6 +432,8 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 							return true, nil | ||||
| 						case noRefresh: | ||||
| 							return false, nil | ||||
| 						case notImplemented: | ||||
| 							return false, providers.ErrNotImplemented | ||||
| 						default: | ||||
| 							return false, errors.New("error refreshing session") | ||||
| 						} | ||||
|  | @ -448,6 +464,13 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 				expectedErr: nil, | ||||
| 				expectSaved: true, | ||||
| 			}), | ||||
| 			Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: notImplemented, | ||||
| 				}, | ||||
| 				expectedErr: nil, | ||||
| 				expectSaved: true, | ||||
| 			}), | ||||
| 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "RefreshError", | ||||
|  |  | |||
|  | @ -130,19 +130,8 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS | |||
| } | ||||
| 
 | ||||
| // RefreshSession refreshes the user's session
 | ||||
| func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 	if s == nil { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// HACK:
 | ||||
| 	// Pretend `RefreshSession` occurred so `ValidateSession` isn't called
 | ||||
| 	// on every request after any potential set refresh period elapses.
 | ||||
| 	// See `middleware.refreshSession` for detailed logic & explanation.
 | ||||
| 	//
 | ||||
| 	// Intentionally doesn't use `ErrNotImplemented` since all providers will
 | ||||
| 	// call this and we don't want to force them to implement this dummy logic.
 | ||||
| 	return true, nil | ||||
| func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) { | ||||
| 	return false, ErrNotImplemented | ||||
| } | ||||
| 
 | ||||
| // CreateSessionFromToken converts Bearer IDTokens into sessions
 | ||||
|  |  | |||
|  | @ -22,12 +22,12 @@ func TestRefresh(t *testing.T) { | |||
| 	ss.SetExpiresOn(expires) | ||||
| 
 | ||||
| 	refreshed, err := p.RefreshSession(context.Background(), ss) | ||||
| 	assert.True(t, refreshed) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.False(t, refreshed) | ||||
| 	assert.Equal(t, ErrNotImplemented, err) | ||||
| 
 | ||||
| 	refreshed, err = p.RefreshSession(context.Background(), nil) | ||||
| 	assert.False(t, refreshed) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, ErrNotImplemented, err) | ||||
| } | ||||
| 
 | ||||
| func TestAcrValuesNotConfigured(t *testing.T) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue