Add automated retries for IdP and CSRF errors

This implements functionality to automate restarting a login.
The automation of CSRF errors catches missing CSRF cookies to retry as fallback for Logins that required too much time.
The IdP errors are implemented because some IdP Providers return with errors where it is advised to directly retry the login process. For example keycloak handles its "temporarily_unavailable" error like this. See https://www.keycloak.org/securing-apps/oidc-layers#_oidc-errors

Signed-off-by: Tim Fiernkranz <tim.fiernkranz@tngtech.com>
This commit is contained in:
Tim Fiernkranz 2025-08-26 11:05:20 +02:00
parent 3a55dadbe8
commit c09b909dd3
5 changed files with 156 additions and 44 deletions

View File

@ -7,6 +7,7 @@
## Breaking Changes
## Changes since v7.14.2
- [#3177](https://github.com/oauth2-proxy/oauth2-proxy/pull/3177) feat: add options to automate retries for IdP or CSRF erros (@TimFiernkranz)
# V7.14.2

View File

@ -77,7 +77,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
Provider specific options can be found on their respective subpages.
| Flag / Config Field | Type | Description | Default |
| --------------------------------------------------------------------------------------------------- | -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------- |
| --------------------------------------------------------------------------------------------------- |----------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| --------------------- |
| flag: `--acr-values`<br/>toml: `acr_values` | string | optional, see [docs](https://openid.net/specs/openid-connect-eap-acr-values-1_0.html#acrValues) | `""` |
| flag: `--allowed-group`<br/>toml: `allowed_groups` | string \| list | Restrict login to members of a group or list of groups. Furthermore, if you aren't setting the `scope` and use `allowed_groups` with the generic OIDC provider the scope `groups` gets added implicitly. | |
| flag: `--approval-prompt`<br/>toml: `approval_prompt` | string | OAuth approval_prompt | `"force"` |
@ -112,14 +112,16 @@ Provider specific options can be found on their respective subpages.
| flag: `--skip-oidc-discovery`<br/>toml: `skip_oidc_discovery` | bool | bypass OIDC endpoint discovery. `--login-url`, `--redeem-url` and `--oidc-jwks-url` must be configured in this case | false |
| flag: `--use-system-trust-store`<br/>toml: `use_system_trust_store` | bool | Determines if `provider-ca-file` files and the system trust store are used. If set to true, your custom CA files and the system trust store are used otherwise only your custom CA files. | false |
| flag: `--validate-url`<br/>toml: `validate_url` | string | Access token validation endpoint | |
| flag: `--retry-idp-errors`<br/>toml: `retry_idp_errors` | string \| list | Errors returned from the IdP that should trigger an automated retry. Also increase `max-automated-retries` if this is used. | |
### Cookie Options
| Flag / Config Field | Type | Description | Default |
| --------------------------------------------------------------------------------- | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------- |
|-----------------------------------------------------------------------------------|----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------|
| flag: `--cookie-csrf-expire`<br/>toml: `cookie_csrf_expire` | duration | expire timeframe for CSRF cookie | 15m |
| flag: `--cookie-csrf-per-request`<br/>toml:`cookie_csrf_per_request` | bool | Enable having different CSRF cookies per request, making it possible to have parallel requests. | false |
| flag: `--cookie-csrf-per-request-limit`<br/>toml: `cookie_csrf_per_request_limit` | int | Sets a limit on the number of CSRF requests cookies that oauth2-proxy will create. The oldest cookie will be removed. Useful if users end up with 431 Request headers too large status codes. Only effective if --cookie-csrf-per-request is true | "infinite" |
| flag: `--retry-csrf-errors`<br/>toml: `retry_csrf_errors` | bool | If set to true the oauth_proxy will automatically restart the authentication flow if an error with CSRF cookies occurs. To work `max-automated-retry` also needs to be increased. | false |
| flag: `--cookie-domain`<br/>toml: `cookie_domains` | string \| list | Optional cookie domains to force cookies to (e.g. `.yourcompany.com`). The longest domain matching the request's host will be used (or the shortest cookie domain if there is no match). | |
| flag: `--cookie-expire`<br/>toml: `cookie_expire` | duration | expire timeframe for cookie. If set to 0, cookie becomes a session-cookie which will expire when the browser is closed. | 168h0m0s |
| flag: `--cookie-httponly`<br/>toml: `cookie_httponly` | bool | set HttpOnly cookie flag | true |
@ -192,7 +194,7 @@ Provider specific options can be found on their respective subpages.
### Proxy Options
| Flag / Config Field | Type | Description | Default |
| ----------------------------------------------------------------------------- | -------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------- |
| ----------------------------------------------------------------------------- |----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------|
| flag: `--allow-query-semicolons`<br/>toml: `allow_query_semicolons` | bool | allow the use of semicolons in query args ([required for some legacy applications](https://github.com/golang/go/issues/25192)) | `false` |
| flag: `--api-route`<br/>toml: `api_routes` | string \| list | Requests to these paths must already be authenticated with a cookie, or a JWT if `--skip-jwt-bearer-tokens` is set. No redirect to login will be done. Return 401 if not. Format: path_regex | |
| flag: `--authenticated-emails-file`<br/>toml: `authenticated_emails_file` | string | authenticate against emails via file (one per line) | |
@ -218,6 +220,7 @@ Provider specific options can be found on their respective subpages.
| flag: `--ssl-insecure-skip-verify`<br/>toml: `ssl_insecure_skip_verify` | bool | skip validation of certificates presented when using HTTPS providers | false |
| flag: `--trusted-ip`<br/>toml: `trusted_ips` | string \| list | list of IPs or CIDR ranges to allow to bypass authentication (may be given multiple times). When combined with `--reverse-proxy` and optionally `--real-client-ip-header` this will evaluate the trust of the IP stored in an HTTP header by a reverse proxy rather than the layer-3/4 remote address. WARNING: trusting IPs has inherent security flaws, especially when obtaining the IP address from an HTTP header (reverse-proxy mode). Use this option only if you understand the risks and how to manage them. | |
| flag: `--whitelist-domain`<br/>toml: `whitelist_domains` | string \| list | allowed domains for redirection after authentication. Prefix domain with a `.` or a `*.` to allow subdomains (e.g. `.example.com`, `*.example.com`)&nbsp;[^2] | |
| flag: `--max-automated-retries`<br/>toml: `max_automated_retries` | int | Maximum amount of automated retries that the oauth_proxy is attempting. | 0 |
[^2]: When using the `whitelist-domain` option, any domain prefixed with a `.` or a `*.` will allow any subdomain of the specified domain as a valid redirect URL. By default, only empty ports are allowed. This translates to allowing the default port of the URL's protocol (80 for HTTP, 443 for HTTPS, etc.) since browsers omit them. To allow only a specific port, add it to the whitelisted domain: `example.com:8080`. To allow any port, use `*`: `example.com:*`.

View File

@ -13,6 +13,8 @@ import (
"os"
"os/signal"
"regexp"
"slices"
"strconv"
"strings"
"syscall"
"time"
@ -113,7 +115,10 @@ type OAuthProxy struct {
redirectValidator redirect.Validator
appDirector redirect.AppDirector
encodeState bool
encodeState bool
maxAutomatedRetries int
idpErrorsToRetry []string
retryCsrfErrors bool
}
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@ -238,16 +243,19 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
allowQuerySemicolons: opts.AllowQuerySemicolons,
trustedIPs: trustedIPs,
basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups,
sessionChain: sessionChain,
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
upstreamProxy: upstreamProxy,
redirectValidator: redirectValidator,
appDirector: appDirector,
encodeState: opts.EncodeState,
basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups,
sessionChain: sessionChain,
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
upstreamProxy: upstreamProxy,
redirectValidator: redirectValidator,
appDirector: appDirector,
encodeState: opts.EncodeState,
maxAutomatedRetries: opts.MaxAutomatedRetries,
idpErrorsToRetry: opts.IdpErrorsToRetry,
retryCsrfErrors: opts.RetryCsrfErrors,
}
p.buildServeMux(opts.ProxyPrefix)
@ -789,13 +797,32 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
}
}
// OAuthRestart restarts the OAuth2 authentication flow with a specified retryCounter
func (p *OAuthProxy) OAuthRestart(rw http.ResponseWriter, req *http.Request, redirectPath string, retryCount int) (string, error) {
if !p.redirectValidator.IsValidRedirect(redirectPath) {
return fmt.Sprintf("Login Failed: The given redirect (%v) was not valid. Please try again.", redirectPath), errors.New("invalid redirect")
}
if retryCount >= p.maxAutomatedRetries {
return fmt.Sprintf("Login Failed: The maximum amount (%v) of automated retries exceeded.", p.maxAutomatedRetries), errors.New("retries exceeded")
}
logger.Printf("Restarting Oauth2 flow (%v/%v)", retryCount, p.maxAutomatedRetries)
// Handlers typically use this parameter to preserve the login redirect path.
// Since initialization now occurs within the proxy, we must set this parameter
// to retain the originally requested path for restarted requests.
req.Form.Set("rd", redirectPath)
p.doOAuthStart(rw, req, nil, retryCount+1)
return "", nil
}
// OAuthStart starts the OAuth2 authentication flow
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
// start the flow permitting login URL query parameters to be overridden from the request URL
p.doOAuthStart(rw, req, req.URL.Query())
p.doOAuthStart(rw, req, req.URL.Query(), 0)
}
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values) {
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values, retryCount int) {
extraParams := p.provider.Data().LoginURLParams(overrides)
prepareNoCache(rw)
@ -840,7 +867,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove
callbackRedirect := p.getOAuthRedirectURI(req)
loginURL := p.provider.GetLoginURL(
callbackRedirect,
encodeState(csrf.HashOAuthState(), appRedirect, p.encodeState),
encodeState(csrf.HashOAuthState(), appRedirect, retryCount, p.encodeState),
csrf.HashOIDCNonce(),
extraParams,
)
@ -865,22 +892,33 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
errorString := req.Form.Get("error")
if errorString != "" {
logger.Errorf("Error while parsing OAuth2 callback: %s", errorString)
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
// Set the debug message and override the non debug message to be the same for this case
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
return
}
nonce, appRedirect, err := decodeState(req.Form.Get("state"), p.encodeState)
nonce, appRedirect, retries, err := decodeState(req.Form.Get("state"), p.encodeState)
if err != nil {
logger.Errorf("Error while parsing OAuth2 state: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
errorString := req.Form.Get("error")
if errorString != "" {
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
logger.Error(message)
if slices.Contains(p.idpErrorsToRetry, errorString) {
restartErrString, restartErr := p.OAuthRestart(rw, req, appRedirect, retries)
if restartErr != nil {
logger.Errorf("Encountered Error while restarting: %s", restartErr)
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
}
return
}
// Set the debug message and override the non debug message to be the same for this case
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
return
}
// calculate the cookie name
cookieName := cookies.GenerateCookieName(p.CookieOptions, nonce)
// Try to find the CSRF cookie and decode it
@ -889,8 +927,16 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
// There are a lot of issues opened complaining about missing CSRF cookies.
// Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues.
LoggingCSRFCookiesInOAuthCallback(req, cookieName)
logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.")
message := fmt.Sprintf("Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
logger.Println(req, logger.AuthFailure, message)
restartErrString, restartErr := p.OAuthRestart(rw, req, appRedirect, retries)
if restartErr != nil {
logger.Errorf("Encountered Error while restarting: %s", restartErr)
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
}
return
}
@ -933,7 +979,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
logger.Errorf("Error with authorization: %v", err)
}
if p.Validator(session.Email) && authorized {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2 after %v retries: %s", retries, session)
err := p.SaveSession(rw, req, session)
if err != nil {
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
@ -1037,7 +1083,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
// start OAuth flow, but only with the default login URL params - do not
// consider this request's query params as potential overrides, since
// the user did not explicitly start the login flow
p.doOAuthStart(rw, req, nil)
p.doOAuthStart(rw, req, nil, 0)
} else {
p.SignInPage(rw, req, http.StatusForbidden)
}
@ -1250,8 +1296,9 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool {
// encodeState builds the OAuth state param out of our nonce and
// original application redirect
func encodeState(nonce string, redirect string, encode bool) string {
rawString := fmt.Sprintf("%v:%v", nonce, redirect)
func encodeState(nonce string, redirect string, retryCount int, encode bool) string {
redirectEscaped := url.QueryEscape(redirect)
rawString := fmt.Sprintf("%v:%v:%v", nonce, redirectEscaped, retryCount)
if encode {
return base64.RawURLEncoding.EncodeToString([]byte(rawString))
}
@ -1260,18 +1307,36 @@ func encodeState(nonce string, redirect string, encode bool) string {
// decodeState splits the reflected OAuth state response back into
// the nonce and original application redirect
func decodeState(state string, encode bool) (string, string, error) {
func decodeState(state string, encode bool) (string, string, int, error) {
toParse := state
if encode {
decoded, _ := base64.RawURLEncoding.DecodeString(state)
toParse = string(decoded)
}
parsedState := strings.SplitN(toParse, ":", 2)
if len(parsedState) != 2 {
return "", "", errors.New("invalid length")
parsedState := strings.Split(toParse, ":")
numStateParams := len(parsedState)
if numStateParams < 2 || numStateParams > 3 {
return "", "", -1, fmt.Errorf("invalid number of state parameters (%v)", numStateParams)
}
return parsedState[0], parsedState[1], nil
nonce := parsedState[0]
redirectPath, err := url.QueryUnescape(parsedState[1])
if err != nil {
return "", "", -1, fmt.Errorf("invalid redirectPath url (%v)", parsedState[1])
}
retries := 0
if numStateParams > 2 {
retries, err = strconv.Atoi(parsedState[2])
if err != nil {
return "", "", -1, fmt.Errorf("invalid retry count (%v)", parsedState[2])
}
}
return nonce, redirectPath, retries, nil
}
// addHeadersForProxying adds the appropriate headers the request / response for proxying

View File

@ -414,7 +414,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie
http.MethodGet,
fmt.Sprintf(
"/oauth2/callback?code=callback_code&state=%s",
encodeState(csrf.HashOAuthState(), "%2F", false),
encodeState(csrf.HashOAuthState(), "%2F", 0, false),
),
strings.NewReader(""),
)
@ -3358,24 +3358,61 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
func TestStateEncodesCorrectly(t *testing.T) {
state := "some_state_to_test"
nonce := "some_nonce_to_test"
retryCount := 3
encodedResult := encodeState(nonce, state, true)
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult)
encodedResult := encodeState(nonce, state, retryCount, true)
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", encodedResult)
notEncodedResult := encodeState(nonce, state, false)
assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult)
notEncodedResult := encodeState(nonce, state, retryCount, false)
assert.Equal(t, "some_nonce_to_test:some_state_to_test:3", notEncodedResult)
}
func TestStateEncodesWithColonInRedirect(t *testing.T) {
redirect := "test:url"
nonce := "nonce_value"
retryCount := 3
encodedResult := encodeState(nonce, redirect, retryCount, true)
assert.Equal(t, "bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", encodedResult)
notEncodedResult := encodeState(nonce, redirect, retryCount, false)
assert.Equal(t, "nonce_value:test%3Aurl:3", notEncodedResult)
}
func TestStateDecodesCorrectly(t *testing.T) {
nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true)
nonce, redirect, retryCount, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", true)
assert.Equal(t, "some_nonce_to_test", nonce)
assert.Equal(t, "some_state_to_test", redirect)
assert.Equal(t, 3, retryCount)
nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false)
nonce2, redirect2, retryCount2, _ := decodeState("some_nonce_to_test:some_state_to_test:3", false)
assert.Equal(t, "some_nonce_to_test", nonce2)
assert.Equal(t, "some_state_to_test", redirect2)
assert.Equal(t, 3, retryCount2)
}
func TestStateDecodesWithColonInRedirect(t *testing.T) {
nonce, redirect, retryCount, _ := decodeState("bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", true)
assert.Equal(t, "nonce_value", nonce)
assert.Equal(t, "test:url", redirect)
assert.Equal(t, 3, retryCount)
nonce2, redirect2, retryCount2, _ := decodeState("nonce_value:test%3Aurl:3", false)
assert.Equal(t, "nonce_value", nonce2)
assert.Equal(t, "test:url", redirect2)
assert.Equal(t, 3, retryCount2)
}
func TestStateDecodesWithoutRetryCount(t *testing.T) {
nonce, redirect, retryCount, _ := decodeState("nonce_value:url", false)
assert.Equal(t, "nonce_value", nonce)
assert.Equal(t, "url", redirect)
assert.Equal(t, 0, retryCount)
}
func TestAuthOnlyAllowedEmails(t *testing.T) {

View File

@ -34,6 +34,9 @@ type Options struct {
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"`
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"`
HtpasswdUserGroups []string `flag:"htpasswd-user-group" cfg:"htpasswd_user_groups"`
MaxAutomatedRetries int `flag:"max-automated-retries" cfg:"max_automated_retries"`
IdpErrorsToRetry []string `flag:"retry-idp-errors" cfg:"retry_idp_errors"`
RetryCsrfErrors bool `flag:"retry-csrf-errors" cfg:"retry_csrf_errors"`
Cookie Cookie `cfg:",squash"`
Session SessionOptions `cfg:",squash"`
@ -161,6 +164,9 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
flagSet.Int("max-automated-retries", 0, "Maximum number of automated retries for callback errors")
flagSet.StringSlice("retry-idp-errors", []string{}, "Errors from IdP that should be automatically retried")
flagSet.Bool("retry-csrf-errors", false, "If true retries Csrf errors automatically")
flagSet.AddFlagSet(cookieFlagSet())
flagSet.AddFlagSet(loggingFlagSet())