diff --git a/oauthproxy.go b/oauthproxy.go index 7526d641..0f8ad2bf 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -13,6 +13,8 @@ import ( "os" "os/signal" "regexp" + "slices" + "strconv" "strings" "syscall" "time" @@ -113,7 +115,10 @@ type OAuthProxy struct { redirectValidator redirect.Validator appDirector redirect.AppDirector - encodeState bool + encodeState bool + maxAutomatedRetries int + idpErrorsToRetry []string + retryCsrfErrors bool } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided @@ -238,16 +243,19 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr allowQuerySemicolons: opts.AllowQuerySemicolons, trustedIPs: trustedIPs, - basicAuthValidator: basicAuthValidator, - basicAuthGroups: opts.HtpasswdUserGroups, - sessionChain: sessionChain, - headersChain: headersChain, - preAuthChain: preAuthChain, - pageWriter: pageWriter, - upstreamProxy: upstreamProxy, - redirectValidator: redirectValidator, - appDirector: appDirector, - encodeState: opts.EncodeState, + basicAuthValidator: basicAuthValidator, + basicAuthGroups: opts.HtpasswdUserGroups, + sessionChain: sessionChain, + headersChain: headersChain, + preAuthChain: preAuthChain, + pageWriter: pageWriter, + upstreamProxy: upstreamProxy, + redirectValidator: redirectValidator, + appDirector: appDirector, + encodeState: opts.EncodeState, + maxAutomatedRetries: opts.MaxAutomatedRetries, + idpErrorsToRetry: opts.IdpErrorsToRetry, + retryCsrfErrors: opts.RetryCsrfErrors, } p.buildServeMux(opts.ProxyPrefix) @@ -788,13 +796,32 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) { } } +// OAuthRestart restarts the OAuth2 authentication flow with a specified retryCounter +func (p *OAuthProxy) OAuthRestart(rw http.ResponseWriter, req *http.Request, redirectPath string, retryCount int) (error, string) { + if !p.redirectValidator.IsValidRedirect(redirectPath) { + return errors.New("invalid redirect"), fmt.Sprintf("Login Failed: The given redirect (%v) was not valid. Please try again.", redirectPath) + } + + if retryCount >= p.maxAutomatedRetries { + return errors.New("retries exceeded"), fmt.Sprintf("Login Failed: The maximum amount (%v) of automated retries exceeded.", p.maxAutomatedRetries) + } + + logger.Printf("Restarting Oauth2 flow (%v/%v)", retryCount, p.maxAutomatedRetries) + // Handlers typically use this parameter to preserve the login redirect path. + // Since initialization now occurs within the proxy, we must set this parameter + // to retain the originally requested path for restarted requests. + req.Form.Set("rd", redirectPath) + p.doOAuthStart(rw, req, nil, retryCount+1) + return nil, "" +} + // OAuthStart starts the OAuth2 authentication flow func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { // start the flow permitting login URL query parameters to be overridden from the request URL - p.doOAuthStart(rw, req, req.URL.Query()) + p.doOAuthStart(rw, req, req.URL.Query(), 0) } -func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values) { +func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values, retryCount int) { extraParams := p.provider.Data().LoginURLParams(overrides) prepareNoCache(rw) @@ -839,7 +866,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove callbackRedirect := p.getOAuthRedirectURI(req) loginURL := p.provider.GetLoginURL( callbackRedirect, - encodeState(csrf.HashOAuthState(), appRedirect, p.encodeState), + encodeState(csrf.HashOAuthState(), appRedirect, retryCount, p.encodeState), csrf.HashOIDCNonce(), extraParams, ) @@ -864,22 +891,33 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } - errorString := req.Form.Get("error") - if errorString != "" { - logger.Errorf("Error while parsing OAuth2 callback: %s", errorString) - message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString) - // Set the debug message and override the non debug message to be the same for this case - p.ErrorPage(rw, req, http.StatusForbidden, message, message) - return - } - nonce, appRedirect, err := decodeState(req.Form.Get("state"), p.encodeState) + nonce, appRedirect, retries, err := decodeState(req.Form.Get("state"), p.encodeState) if err != nil { logger.Errorf("Error while parsing OAuth2 state: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } + errorString := req.Form.Get("error") + if errorString != "" { + message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString) + logger.Error(message) + + if slices.Contains(p.idpErrorsToRetry, errorString) { + restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries) + + if restartErr != nil { + logger.Errorf("Encountered Error while restarting: %s", restartErr) + p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString) + } + return + } + // Set the debug message and override the non debug message to be the same for this case + p.ErrorPage(rw, req, http.StatusForbidden, message, message) + return + } + // calculate the cookie name cookieName := cookies.GenerateCookieName(p.CookieOptions, nonce) // Try to find the CSRF cookie and decode it @@ -888,8 +926,16 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { // There are a lot of issues opened complaining about missing CSRF cookies. // Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues. LoggingCSRFCookiesInOAuthCallback(req, cookieName) - logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce) - p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.") + message := fmt.Sprintf("Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce) + logger.Println(req, logger.AuthFailure, message) + + restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries) + + if restartErr != nil { + logger.Errorf("Encountered Error while restarting: %s", restartErr) + p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString) + } + return } @@ -932,7 +978,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { logger.Errorf("Error with authorization: %v", err) } if p.Validator(session.Email) && authorized { - logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) + logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2 after %v retries: %s", retries, session) err := p.SaveSession(rw, req, session) if err != nil { logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) @@ -1029,7 +1075,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // start OAuth flow, but only with the default login URL params - do not // consider this request's query params as potential overrides, since // the user did not explicitly start the login flow - p.doOAuthStart(rw, req, nil) + p.doOAuthStart(rw, req, nil, 0) } else { p.SignInPage(rw, req, http.StatusForbidden) } @@ -1242,8 +1288,9 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool { // encodeState builds the OAuth state param out of our nonce and // original application redirect -func encodeState(nonce string, redirect string, encode bool) string { - rawString := fmt.Sprintf("%v:%v", nonce, redirect) +func encodeState(nonce string, redirect string, retryCount int, encode bool) string { + redirectEscaped := url.QueryEscape(redirect) + rawString := fmt.Sprintf("%v:%v:%v", nonce, redirectEscaped, retryCount) if encode { return base64.RawURLEncoding.EncodeToString([]byte(rawString)) } @@ -1252,18 +1299,36 @@ func encodeState(nonce string, redirect string, encode bool) string { // decodeState splits the reflected OAuth state response back into // the nonce and original application redirect -func decodeState(state string, encode bool) (string, string, error) { +func decodeState(state string, encode bool) (string, string, int, error) { toParse := state if encode { decoded, _ := base64.RawURLEncoding.DecodeString(state) toParse = string(decoded) } - parsedState := strings.SplitN(toParse, ":", 2) - if len(parsedState) != 2 { - return "", "", errors.New("invalid length") + parsedState := strings.Split(toParse, ":") + numStateParams := len(parsedState) + + if numStateParams < 2 || numStateParams > 3 { + return "", "", -1, errors.New(fmt.Sprintf("invalid number of state parameters (%v)", numStateParams)) } - return parsedState[0], parsedState[1], nil + + nonce := parsedState[0] + + redirectPath, err := url.QueryUnescape(parsedState[1]) + if err != nil { + return "", "", -1, errors.New(fmt.Sprintf("invalid redirectPath url (%v)", parsedState[1])) + } + + retries := 0 + if numStateParams > 2 { + retries, err = strconv.Atoi(parsedState[2]) + if err != nil { + return "", "", -1, errors.New(fmt.Sprintf("invalid retry count (%v)", parsedState[2])) + } + } + + return nonce, redirectPath, retries, nil } // addHeadersForProxying adds the appropriate headers the request / response for proxying diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 488b8cea..15240ee7 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -413,7 +413,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie http.MethodGet, fmt.Sprintf( "/oauth2/callback?code=callback_code&state=%s", - encodeState(csrf.HashOAuthState(), "%2F", false), + encodeState(csrf.HashOAuthState(), "%2F", 0, false), ), strings.NewReader(""), ) @@ -3294,24 +3294,61 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) { func TestStateEncodesCorrectly(t *testing.T) { state := "some_state_to_test" nonce := "some_nonce_to_test" + retryCount := 3 - encodedResult := encodeState(nonce, state, true) - assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult) + encodedResult := encodeState(nonce, state, retryCount, true) + assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", encodedResult) - notEncodedResult := encodeState(nonce, state, false) - assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult) + notEncodedResult := encodeState(nonce, state, retryCount, false) + assert.Equal(t, "some_nonce_to_test:some_state_to_test:3", notEncodedResult) +} + +func TestStateEncodesWithColonInRedirect(t *testing.T) { + redirect := "test:url" + nonce := "nonce_value" + retryCount := 3 + + encodedResult := encodeState(nonce, redirect, retryCount, true) + assert.Equal(t, "bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", encodedResult) + + notEncodedResult := encodeState(nonce, redirect, retryCount, false) + assert.Equal(t, "nonce_value:test%3Aurl:3", notEncodedResult) } func TestStateDecodesCorrectly(t *testing.T) { - nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true) + nonce, redirect, retryCount, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", true) assert.Equal(t, "some_nonce_to_test", nonce) assert.Equal(t, "some_state_to_test", redirect) + assert.Equal(t, 3, retryCount) - nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false) + nonce2, redirect2, retryCount2, _ := decodeState("some_nonce_to_test:some_state_to_test:3", false) assert.Equal(t, "some_nonce_to_test", nonce2) assert.Equal(t, "some_state_to_test", redirect2) + assert.Equal(t, 3, retryCount2) +} + +func TestStateDecodesWithColonInRedirect(t *testing.T) { + nonce, redirect, retryCount, _ := decodeState("bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", true) + + assert.Equal(t, "nonce_value", nonce) + assert.Equal(t, "test:url", redirect) + assert.Equal(t, 3, retryCount) + + nonce2, redirect2, retryCount2, _ := decodeState("nonce_value:test%3Aurl:3", false) + + assert.Equal(t, "nonce_value", nonce2) + assert.Equal(t, "test:url", redirect2) + assert.Equal(t, 3, retryCount2) +} + +func TestStateDecodesWithoutRetryCount(t *testing.T) { + nonce, redirect, retryCount, _ := decodeState("nonce_value:url", false) + + assert.Equal(t, "nonce_value", nonce) + assert.Equal(t, "url", redirect) + assert.Equal(t, 0, retryCount) } func TestAuthOnlyAllowedEmails(t *testing.T) { diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 8fa72c7c..34816cdc 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -34,6 +34,9 @@ type Options struct { WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"` HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` HtpasswdUserGroups []string `flag:"htpasswd-user-group" cfg:"htpasswd_user_groups"` + MaxAutomatedRetries int `flag:"max-automated-retries" cfg:"max_automated_retries"` + IdpErrorsToRetry []string `flag:"retry-idp-errors" cfg:"retry_idp_errors"` + RetryCsrfErrors bool `flag:"retry-csrf-errors" cfg:"retry_csrf_errors"` Cookie Cookie `cfg:",squash"` Session SessionOptions `cfg:",squash"` @@ -161,6 +164,9 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") + flagSet.Int("max-automated-retries", 0, "Maximum number of automated retries for callback errors") + flagSet.StringSlice("retry-idp-errors", []string{}, "Errors from IdP that should be automatically retried") + flagSet.Bool("retry-csrf-errors", false, "If true retries Csrf errors automatically") flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet())