Add automated retries for IdP and CSRF errors
This implements functionality to automate restarting a login. The automation of CSRF errors catches missing CSRF cookies to retry as fallback for Logins that required too much time. The IdP errors are implemented because some IdP Providers return with errors where it is advised to directly retry the login process. For example keycloak handles its "temporarily_unavailable" error like this. See https://www.keycloak.org/securing-apps/oidc-layers#_oidc-errors
This commit is contained in:
parent
8afb047e01
commit
bad29b556e
133
oauthproxy.go
133
oauthproxy.go
|
|
@ -13,6 +13,8 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
|
@ -113,7 +115,10 @@ type OAuthProxy struct {
|
|||
redirectValidator redirect.Validator
|
||||
appDirector redirect.AppDirector
|
||||
|
||||
encodeState bool
|
||||
encodeState bool
|
||||
maxAutomatedRetries int
|
||||
idpErrorsToRetry []string
|
||||
retryCsrfErrors bool
|
||||
}
|
||||
|
||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||
|
|
@ -238,16 +243,19 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||
allowQuerySemicolons: opts.AllowQuerySemicolons,
|
||||
trustedIPs: trustedIPs,
|
||||
|
||||
basicAuthValidator: basicAuthValidator,
|
||||
basicAuthGroups: opts.HtpasswdUserGroups,
|
||||
sessionChain: sessionChain,
|
||||
headersChain: headersChain,
|
||||
preAuthChain: preAuthChain,
|
||||
pageWriter: pageWriter,
|
||||
upstreamProxy: upstreamProxy,
|
||||
redirectValidator: redirectValidator,
|
||||
appDirector: appDirector,
|
||||
encodeState: opts.EncodeState,
|
||||
basicAuthValidator: basicAuthValidator,
|
||||
basicAuthGroups: opts.HtpasswdUserGroups,
|
||||
sessionChain: sessionChain,
|
||||
headersChain: headersChain,
|
||||
preAuthChain: preAuthChain,
|
||||
pageWriter: pageWriter,
|
||||
upstreamProxy: upstreamProxy,
|
||||
redirectValidator: redirectValidator,
|
||||
appDirector: appDirector,
|
||||
encodeState: opts.EncodeState,
|
||||
maxAutomatedRetries: opts.MaxAutomatedRetries,
|
||||
idpErrorsToRetry: opts.IdpErrorsToRetry,
|
||||
retryCsrfErrors: opts.RetryCsrfErrors,
|
||||
}
|
||||
p.buildServeMux(opts.ProxyPrefix)
|
||||
|
||||
|
|
@ -788,13 +796,32 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// OAuthRestart restarts the OAuth2 authentication flow with a specified retryCounter
|
||||
func (p *OAuthProxy) OAuthRestart(rw http.ResponseWriter, req *http.Request, redirectPath string, retryCount int) (error, string) {
|
||||
if !p.redirectValidator.IsValidRedirect(redirectPath) {
|
||||
return errors.New("invalid redirect"), fmt.Sprintf("Login Failed: The given redirect (%v) was not valid. Please try again.", redirectPath)
|
||||
}
|
||||
|
||||
if retryCount >= p.maxAutomatedRetries {
|
||||
return errors.New("retries exceeded"), fmt.Sprintf("Login Failed: The maximum amount (%v) of automated retries exceeded.", p.maxAutomatedRetries)
|
||||
}
|
||||
|
||||
logger.Printf("Restarting Oauth2 flow (%v/%v)", retryCount, p.maxAutomatedRetries)
|
||||
// Handlers typically use this parameter to preserve the login redirect path.
|
||||
// Since initialization now occurs within the proxy, we must set this parameter
|
||||
// to retain the originally requested path for restarted requests.
|
||||
req.Form.Set("rd", redirectPath)
|
||||
p.doOAuthStart(rw, req, nil, retryCount+1)
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// OAuthStart starts the OAuth2 authentication flow
|
||||
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||
// start the flow permitting login URL query parameters to be overridden from the request URL
|
||||
p.doOAuthStart(rw, req, req.URL.Query())
|
||||
p.doOAuthStart(rw, req, req.URL.Query(), 0)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values) {
|
||||
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values, retryCount int) {
|
||||
extraParams := p.provider.Data().LoginURLParams(overrides)
|
||||
prepareNoCache(rw)
|
||||
|
||||
|
|
@ -839,7 +866,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove
|
|||
callbackRedirect := p.getOAuthRedirectURI(req)
|
||||
loginURL := p.provider.GetLoginURL(
|
||||
callbackRedirect,
|
||||
encodeState(csrf.HashOAuthState(), appRedirect, p.encodeState),
|
||||
encodeState(csrf.HashOAuthState(), appRedirect, retryCount, p.encodeState),
|
||||
csrf.HashOIDCNonce(),
|
||||
extraParams,
|
||||
)
|
||||
|
|
@ -864,22 +891,33 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
logger.Errorf("Error while parsing OAuth2 callback: %s", errorString)
|
||||
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
|
||||
// Set the debug message and override the non debug message to be the same for this case
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
|
||||
return
|
||||
}
|
||||
|
||||
nonce, appRedirect, err := decodeState(req.Form.Get("state"), p.encodeState)
|
||||
nonce, appRedirect, retries, err := decodeState(req.Form.Get("state"), p.encodeState)
|
||||
if err != nil {
|
||||
logger.Errorf("Error while parsing OAuth2 state: %v", err)
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
|
||||
logger.Error(message)
|
||||
|
||||
if slices.Contains(p.idpErrorsToRetry, errorString) {
|
||||
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
|
||||
|
||||
if restartErr != nil {
|
||||
logger.Errorf("Encountered Error while restarting: %s", restartErr)
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Set the debug message and override the non debug message to be the same for this case
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
|
||||
return
|
||||
}
|
||||
|
||||
// calculate the cookie name
|
||||
cookieName := cookies.GenerateCookieName(p.CookieOptions, nonce)
|
||||
// Try to find the CSRF cookie and decode it
|
||||
|
|
@ -888,8 +926,16 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
// There are a lot of issues opened complaining about missing CSRF cookies.
|
||||
// Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues.
|
||||
LoggingCSRFCookiesInOAuthCallback(req, cookieName)
|
||||
logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.")
|
||||
message := fmt.Sprintf("Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
|
||||
logger.Println(req, logger.AuthFailure, message)
|
||||
|
||||
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
|
||||
|
||||
if restartErr != nil {
|
||||
logger.Errorf("Encountered Error while restarting: %s", restartErr)
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -932,7 +978,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
logger.Errorf("Error with authorization: %v", err)
|
||||
}
|
||||
if p.Validator(session.Email) && authorized {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2 after %v retries: %s", retries, session)
|
||||
err := p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
|
||||
|
|
@ -1029,7 +1075,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
|||
// start OAuth flow, but only with the default login URL params - do not
|
||||
// consider this request's query params as potential overrides, since
|
||||
// the user did not explicitly start the login flow
|
||||
p.doOAuthStart(rw, req, nil)
|
||||
p.doOAuthStart(rw, req, nil, 0)
|
||||
} else {
|
||||
p.SignInPage(rw, req, http.StatusForbidden)
|
||||
}
|
||||
|
|
@ -1242,8 +1288,9 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool {
|
|||
|
||||
// encodeState builds the OAuth state param out of our nonce and
|
||||
// original application redirect
|
||||
func encodeState(nonce string, redirect string, encode bool) string {
|
||||
rawString := fmt.Sprintf("%v:%v", nonce, redirect)
|
||||
func encodeState(nonce string, redirect string, retryCount int, encode bool) string {
|
||||
redirectEscaped := url.QueryEscape(redirect)
|
||||
rawString := fmt.Sprintf("%v:%v:%v", nonce, redirectEscaped, retryCount)
|
||||
if encode {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(rawString))
|
||||
}
|
||||
|
|
@ -1252,18 +1299,36 @@ func encodeState(nonce string, redirect string, encode bool) string {
|
|||
|
||||
// decodeState splits the reflected OAuth state response back into
|
||||
// the nonce and original application redirect
|
||||
func decodeState(state string, encode bool) (string, string, error) {
|
||||
func decodeState(state string, encode bool) (string, string, int, error) {
|
||||
toParse := state
|
||||
if encode {
|
||||
decoded, _ := base64.RawURLEncoding.DecodeString(state)
|
||||
toParse = string(decoded)
|
||||
}
|
||||
|
||||
parsedState := strings.SplitN(toParse, ":", 2)
|
||||
if len(parsedState) != 2 {
|
||||
return "", "", errors.New("invalid length")
|
||||
parsedState := strings.Split(toParse, ":")
|
||||
numStateParams := len(parsedState)
|
||||
|
||||
if numStateParams < 2 || numStateParams > 3 {
|
||||
return "", "", -1, errors.New(fmt.Sprintf("invalid number of state parameters (%v)", numStateParams))
|
||||
}
|
||||
return parsedState[0], parsedState[1], nil
|
||||
|
||||
nonce := parsedState[0]
|
||||
|
||||
redirectPath, err := url.QueryUnescape(parsedState[1])
|
||||
if err != nil {
|
||||
return "", "", -1, errors.New(fmt.Sprintf("invalid redirectPath url (%v)", parsedState[1]))
|
||||
}
|
||||
|
||||
retries := 0
|
||||
if numStateParams > 2 {
|
||||
retries, err = strconv.Atoi(parsedState[2])
|
||||
if err != nil {
|
||||
return "", "", -1, errors.New(fmt.Sprintf("invalid retry count (%v)", parsedState[2]))
|
||||
}
|
||||
}
|
||||
|
||||
return nonce, redirectPath, retries, nil
|
||||
}
|
||||
|
||||
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
||||
|
|
|
|||
|
|
@ -413,7 +413,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie
|
|||
http.MethodGet,
|
||||
fmt.Sprintf(
|
||||
"/oauth2/callback?code=callback_code&state=%s",
|
||||
encodeState(csrf.HashOAuthState(), "%2F", false),
|
||||
encodeState(csrf.HashOAuthState(), "%2F", 0, false),
|
||||
),
|
||||
strings.NewReader(""),
|
||||
)
|
||||
|
|
@ -3294,24 +3294,61 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
|
|||
func TestStateEncodesCorrectly(t *testing.T) {
|
||||
state := "some_state_to_test"
|
||||
nonce := "some_nonce_to_test"
|
||||
retryCount := 3
|
||||
|
||||
encodedResult := encodeState(nonce, state, true)
|
||||
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult)
|
||||
encodedResult := encodeState(nonce, state, retryCount, true)
|
||||
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", encodedResult)
|
||||
|
||||
notEncodedResult := encodeState(nonce, state, false)
|
||||
assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult)
|
||||
notEncodedResult := encodeState(nonce, state, retryCount, false)
|
||||
assert.Equal(t, "some_nonce_to_test:some_state_to_test:3", notEncodedResult)
|
||||
}
|
||||
|
||||
func TestStateEncodesWithColonInRedirect(t *testing.T) {
|
||||
redirect := "test:url"
|
||||
nonce := "nonce_value"
|
||||
retryCount := 3
|
||||
|
||||
encodedResult := encodeState(nonce, redirect, retryCount, true)
|
||||
assert.Equal(t, "bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", encodedResult)
|
||||
|
||||
notEncodedResult := encodeState(nonce, redirect, retryCount, false)
|
||||
assert.Equal(t, "nonce_value:test%3Aurl:3", notEncodedResult)
|
||||
}
|
||||
|
||||
func TestStateDecodesCorrectly(t *testing.T) {
|
||||
nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true)
|
||||
nonce, redirect, retryCount, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", true)
|
||||
|
||||
assert.Equal(t, "some_nonce_to_test", nonce)
|
||||
assert.Equal(t, "some_state_to_test", redirect)
|
||||
assert.Equal(t, 3, retryCount)
|
||||
|
||||
nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false)
|
||||
nonce2, redirect2, retryCount2, _ := decodeState("some_nonce_to_test:some_state_to_test:3", false)
|
||||
|
||||
assert.Equal(t, "some_nonce_to_test", nonce2)
|
||||
assert.Equal(t, "some_state_to_test", redirect2)
|
||||
assert.Equal(t, 3, retryCount2)
|
||||
}
|
||||
|
||||
func TestStateDecodesWithColonInRedirect(t *testing.T) {
|
||||
nonce, redirect, retryCount, _ := decodeState("bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", true)
|
||||
|
||||
assert.Equal(t, "nonce_value", nonce)
|
||||
assert.Equal(t, "test:url", redirect)
|
||||
assert.Equal(t, 3, retryCount)
|
||||
|
||||
nonce2, redirect2, retryCount2, _ := decodeState("nonce_value:test%3Aurl:3", false)
|
||||
|
||||
assert.Equal(t, "nonce_value", nonce2)
|
||||
assert.Equal(t, "test:url", redirect2)
|
||||
assert.Equal(t, 3, retryCount2)
|
||||
}
|
||||
|
||||
func TestStateDecodesWithoutRetryCount(t *testing.T) {
|
||||
nonce, redirect, retryCount, _ := decodeState("nonce_value:url", false)
|
||||
|
||||
assert.Equal(t, "nonce_value", nonce)
|
||||
assert.Equal(t, "url", redirect)
|
||||
assert.Equal(t, 0, retryCount)
|
||||
}
|
||||
|
||||
func TestAuthOnlyAllowedEmails(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ type Options struct {
|
|||
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"`
|
||||
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"`
|
||||
HtpasswdUserGroups []string `flag:"htpasswd-user-group" cfg:"htpasswd_user_groups"`
|
||||
MaxAutomatedRetries int `flag:"max-automated-retries" cfg:"max_automated_retries"`
|
||||
IdpErrorsToRetry []string `flag:"retry-idp-errors" cfg:"retry_idp_errors"`
|
||||
RetryCsrfErrors bool `flag:"retry-csrf-errors" cfg:"retry_csrf_errors"`
|
||||
|
||||
Cookie Cookie `cfg:",squash"`
|
||||
Session SessionOptions `cfg:",squash"`
|
||||
|
|
@ -161,6 +164,9 @@ func NewFlagSet() *pflag.FlagSet {
|
|||
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
|
||||
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
|
||||
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
|
||||
flagSet.Int("max-automated-retries", 0, "Maximum number of automated retries for callback errors")
|
||||
flagSet.StringSlice("retry-idp-errors", []string{}, "Errors from IdP that should be automatically retried")
|
||||
flagSet.Bool("retry-csrf-errors", false, "If true retries Csrf errors automatically")
|
||||
|
||||
flagSet.AddFlagSet(cookieFlagSet())
|
||||
flagSet.AddFlagSet(loggingFlagSet())
|
||||
|
|
|
|||
Loading…
Reference in New Issue