diff --git a/oauthproxy.go b/oauthproxy.go index f8dc5471..16c90231 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -430,10 +430,13 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi } chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ - SessionStore: sessionStore, - RefreshPeriod: opts.Cookie.Refresh, - RefreshSession: provider.RefreshSession, - ValidateSession: provider.ValidateSession, + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + SessionRefreshLockDuration: opts.Cookie.SessionRefreshLockDuration, + SessionRefreshObtainTimeout: opts.Cookie.SessionRefreshObtainTimeout, + SessionRefreshRetryPeriod: opts.Cookie.SessionRefreshRetryPeriod, + RefreshSession: provider.RefreshSession, + ValidateSession: provider.ValidateSession, })) return chain diff --git a/pkg/apis/options/cookie.go b/pkg/apis/options/cookie.go index 3dee9505..763eb487 100644 --- a/pkg/apis/options/cookie.go +++ b/pkg/apis/options/cookie.go @@ -11,20 +11,23 @@ import ( // Cookie contains configuration options relating to Cookie configuration type Cookie struct { - Name string `flag:"cookie-name" cfg:"cookie_name"` - Secret string `flag:"cookie-secret" cfg:"cookie_secret"` - SecretFile string `flag:"cookie-secret-file" cfg:"cookie_secret_file"` - Domains []string `flag:"cookie-domain" cfg:"cookie_domains"` - Path string `flag:"cookie-path" cfg:"cookie_path"` - Expire time.Duration `flag:"cookie-expire" cfg:"cookie_expire"` - Refresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh"` - Secure bool `flag:"cookie-secure" cfg:"cookie_secure"` - HTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` - SameSite string `flag:"cookie-samesite" cfg:"cookie_samesite"` - CSRFPerRequest bool `flag:"cookie-csrf-per-request" cfg:"cookie_csrf_per_request"` - CSRFPerRequestLimit int `flag:"cookie-csrf-per-request-limit" cfg:"cookie_csrf_per_request_limit"` - CSRFExpire time.Duration `flag:"cookie-csrf-expire" cfg:"cookie_csrf_expire"` - CSRFSameSite string `flag:"cookie-csrf-samesite" cfg:"cookie_csrf_samesite"` + Name string `flag:"cookie-name" cfg:"cookie_name"` + Secret string `flag:"cookie-secret" cfg:"cookie_secret"` + SecretFile string `flag:"cookie-secret-file" cfg:"cookie_secret_file"` + Domains []string `flag:"cookie-domain" cfg:"cookie_domains"` + Path string `flag:"cookie-path" cfg:"cookie_path"` + Expire time.Duration `flag:"cookie-expire" cfg:"cookie_expire"` + Refresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh"` + Secure bool `flag:"cookie-secure" cfg:"cookie_secure"` + HTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` + SameSite string `flag:"cookie-samesite" cfg:"cookie_samesite"` + CSRFPerRequest bool `flag:"cookie-csrf-per-request" cfg:"cookie_csrf_per_request"` + CSRFPerRequestLimit int `flag:"cookie-csrf-per-request-limit" cfg:"cookie_csrf_per_request_limit"` + CSRFExpire time.Duration `flag:"cookie-csrf-expire" cfg:"cookie_csrf_expire"` + CSRFSameSite string `flag:"cookie-csrf-samesite" cfg:"cookie_csrf_samesite"` + SessionRefreshLockDuration time.Duration `flag:"session-refresh-lock-duration" cfg:"session_refresh_lock_duration"` + SessionRefreshObtainTimeout time.Duration `flag:"session-refresh-obtain-timeout" cfg:"session_refresh_obtain_timeout"` + SessionRefreshRetryPeriod time.Duration `flag:"session-refresh-retry-period" cfg:"session_refresh_retry_period"` } func cookieFlagSet() *pflag.FlagSet { @@ -44,26 +47,32 @@ func cookieFlagSet() *pflag.FlagSet { flagSet.Int("cookie-csrf-per-request-limit", 0, "Sets a limit on the number of CSRF requests cookies that oauth2-proxy will create. The oldest cookies will be removed. Useful if users end up with 431 Request headers too large status codes.") flagSet.Duration("cookie-csrf-expire", time.Duration(15)*time.Minute, "expire timeframe for CSRF cookie") flagSet.String("cookie-csrf-samesite", "", "set SameSite CSRF cookie attribute (ie: \"lax\", \"strict\", \"none\", or \"\"). When using the default setting, the CSRF cookie samesite value is taken from the session cookie configuration.") + flagSet.Duration("session-refresh-lock-duration", time.Duration(2)*time.Second, "maximum time allowed for a session refresh attempt; if the refresh request isn't finished within this time, the lock will be released") + flagSet.Duration("session-refresh-obtain-timeout", time.Duration(5)*time.Second, "timeout when attempting to obtain the session lock; if the lock is not obtained before this timeout, the refresh attempt will fail") + flagSet.Duration("session-refresh-retry-period", time.Duration(10)*time.Millisecond, "how long to wait after failing to obtain the lock before trying again") return flagSet } // cookieDefaults creates a Cookie populating each field with its default value func cookieDefaults() Cookie { return Cookie{ - Name: "_oauth2_proxy", - Secret: "", - SecretFile: "", - Domains: nil, - Path: "/", - Expire: time.Duration(168) * time.Hour, - Refresh: time.Duration(0), - Secure: true, - HTTPOnly: true, - SameSite: "", - CSRFPerRequest: false, - CSRFPerRequestLimit: 0, - CSRFExpire: time.Duration(15) * time.Minute, - CSRFSameSite: "", + Name: "_oauth2_proxy", + Secret: "", + SecretFile: "", + Domains: nil, + Path: "/", + Expire: time.Duration(168) * time.Hour, + Refresh: time.Duration(0), + Secure: true, + HTTPOnly: true, + SameSite: "", + CSRFPerRequest: false, + CSRFPerRequestLimit: 0, + CSRFExpire: time.Duration(15) * time.Minute, + CSRFSameSite: "", + SessionRefreshLockDuration: time.Duration(2) * time.Second, + SessionRefreshObtainTimeout: time.Duration(5) * time.Second, + SessionRefreshRetryPeriod: time.Duration(10) * time.Millisecond, } } diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 53238f19..46a4f2fd 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -15,23 +15,6 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) -const ( - // When attempting to obtain the lock, if it's not done before this timeout - // then exit and fail the refresh attempt. - // TODO: This should probably be configurable by the end user. - sessionRefreshObtainTimeout = 5 * time.Second - - // Maximum time allowed for a session refresh attempt. - // If the refresh request isn't finished within this time, the lock will be - // released. - // TODO: This should probably be configurable by the end user. - sessionRefreshLockDuration = 2 * time.Second - - // How long to wait after failing to obtain the lock before trying again. - // TODO: This should probably be configurable by the end user. - sessionRefreshRetryPeriod = 10 * time.Millisecond -) - // isFatalRefreshError checks if a refresh error indicates a revoked or // non-existent session that should be immediately invalidated. // Fatal errors indicate the session is no longer valid at the provider level. @@ -59,6 +42,7 @@ func isFatalRefreshError(err error) bool { return false } + // StoredSessionLoaderOptions contains all of the requirements to construct // a stored session loader. // All options must be provided. @@ -69,6 +53,17 @@ type StoredSessionLoaderOptions struct { // How often should sessions be refreshed RefreshPeriod time.Duration + // Maximum time allowed for a session refresh attempt. + // If the refresh request isn't finished within this time, the lock will be released. + SessionRefreshLockDuration time.Duration + + // Timeout when attempting to obtain the session lock. + // If the lock is not obtained before this timeout, the refresh attempt will fail. + SessionRefreshObtainTimeout time.Duration + + // How long to wait after failing to obtain the lock before trying again. + SessionRefreshRetryPeriod time.Duration + // Provider based session refreshing RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) @@ -84,10 +79,13 @@ type StoredSessionLoaderOptions struct { // If a session was loader by a previous handler, it will not be replaced. func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { ss := &storedSessionLoader{ - store: opts.SessionStore, - refreshPeriod: opts.RefreshPeriod, - sessionRefresher: opts.RefreshSession, - sessionValidator: opts.ValidateSession, + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + sessionRefreshLockDuration: opts.SessionRefreshLockDuration, + sessionRefreshObtainTimeout: opts.SessionRefreshObtainTimeout, + sessionRefreshRetryPeriod: opts.SessionRefreshRetryPeriod, + sessionRefresher: opts.RefreshSession, + sessionValidator: opts.ValidateSession, } return ss.loadSession } @@ -95,10 +93,13 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor // storedSessionLoader is responsible for loading sessions from cookie // identified sessions in the session store. type storedSessionLoader struct { - store sessionsapi.SessionStore - refreshPeriod time.Duration - sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) - sessionValidator func(context.Context, *sessionsapi.SessionState) bool + store sessionsapi.SessionStore + refreshPeriod time.Duration + sessionRefreshLockDuration time.Duration + sessionRefreshObtainTimeout time.Duration + sessionRefreshRetryPeriod time.Duration + sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) + sessionValidator func(context.Context, *sessionsapi.SessionState) bool } // loadSession attempts to load a session as identified by the request cookies. @@ -159,7 +160,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req } var lockObtained bool - ctx, cancel := context.WithTimeout(context.Background(), sessionRefreshObtainTimeout) + ctx, cancel := context.WithTimeout(context.Background(), s.sessionRefreshObtainTimeout) defer cancel() for !lockObtained { @@ -167,11 +168,11 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req case <-ctx.Done(): return errors.New("timeout obtaining session lock") default: - err := session.ObtainLock(req.Context(), sessionRefreshLockDuration) + err := session.ObtainLock(req.Context(), s.sessionRefreshLockDuration) if err != nil && !errors.Is(err, sessionsapi.ErrLockNotObtained) { return fmt.Errorf("error occurred while trying to obtain lock: %v", err) } else if errors.Is(err, sessionsapi.ErrLockNotObtained) { - time.Sleep(sessionRefreshRetryPeriod) + time.Sleep(s.sessionRefreshRetryPeriod) continue } // No error means we obtained the lock diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index c913a4ef..08bf74c3 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -184,10 +184,13 @@ var _ = Describe("Stored Session Suite", func() { rw := httptest.NewRecorder() opts := &StoredSessionLoaderOptions{ - SessionStore: in.store, - RefreshPeriod: in.refreshPeriod, - RefreshSession: in.refreshSession, - ValidateSession: in.validateSession, + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + SessionRefreshLockDuration: 2 * time.Second, + SessionRefreshObtainTimeout: 5 * time.Second, + SessionRefreshRetryPeriod: 10 * time.Millisecond, + RefreshSession: in.refreshSession, + ValidateSession: in.validateSession, } // Create the handler with a next handler that will capture the session @@ -384,8 +387,11 @@ var _ = Describe("Stored Session Suite", func() { sessionRefreshed := false opts := &StoredSessionLoaderOptions{ - SessionStore: store, - RefreshPeriod: in.refreshPeriod, + SessionStore: store, + RefreshPeriod: in.refreshPeriod, + SessionRefreshLockDuration: 2 * time.Second, + SessionRefreshObtainTimeout: 5 * time.Second, + SessionRefreshRetryPeriod: 10 * time.Millisecond, RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) { time.Sleep(10 * time.Millisecond) sessionRefreshed = true @@ -480,8 +486,11 @@ var _ = Describe("Stored Session Suite", func() { } s := &storedSessionLoader{ - refreshPeriod: in.refreshPeriod, - store: store, + refreshPeriod: in.refreshPeriod, + store: store, + sessionRefreshLockDuration: 2 * time.Second, + sessionRefreshObtainTimeout: 5 * time.Second, + sessionRefreshRetryPeriod: 10 * time.Millisecond, sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { refreshed = true switch ss.RefreshToken {