This commit is contained in:
TimFiernkranz 2025-09-30 13:26:28 +03:00 committed by GitHub
commit d08534f8d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 149 additions and 41 deletions

View File

@ -13,6 +13,8 @@ import (
"os" "os"
"os/signal" "os/signal"
"regexp" "regexp"
"slices"
"strconv"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@ -113,7 +115,10 @@ type OAuthProxy struct {
redirectValidator redirect.Validator redirectValidator redirect.Validator
appDirector redirect.AppDirector appDirector redirect.AppDirector
encodeState bool encodeState bool
maxAutomatedRetries int
idpErrorsToRetry []string
retryCsrfErrors bool
} }
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@ -238,16 +243,19 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
allowQuerySemicolons: opts.AllowQuerySemicolons, allowQuerySemicolons: opts.AllowQuerySemicolons,
trustedIPs: trustedIPs, trustedIPs: trustedIPs,
basicAuthValidator: basicAuthValidator, basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups, basicAuthGroups: opts.HtpasswdUserGroups,
sessionChain: sessionChain, sessionChain: sessionChain,
headersChain: headersChain, headersChain: headersChain,
preAuthChain: preAuthChain, preAuthChain: preAuthChain,
pageWriter: pageWriter, pageWriter: pageWriter,
upstreamProxy: upstreamProxy, upstreamProxy: upstreamProxy,
redirectValidator: redirectValidator, redirectValidator: redirectValidator,
appDirector: appDirector, appDirector: appDirector,
encodeState: opts.EncodeState, encodeState: opts.EncodeState,
maxAutomatedRetries: opts.MaxAutomatedRetries,
idpErrorsToRetry: opts.IdpErrorsToRetry,
retryCsrfErrors: opts.RetryCsrfErrors,
} }
p.buildServeMux(opts.ProxyPrefix) p.buildServeMux(opts.ProxyPrefix)
@ -788,13 +796,32 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
} }
} }
// OAuthRestart restarts the OAuth2 authentication flow with a specified retryCounter
func (p *OAuthProxy) OAuthRestart(rw http.ResponseWriter, req *http.Request, redirectPath string, retryCount int) (error, string) {
if !p.redirectValidator.IsValidRedirect(redirectPath) {
return errors.New("invalid redirect"), fmt.Sprintf("Login Failed: The given redirect (%v) was not valid. Please try again.", redirectPath)
}
if retryCount >= p.maxAutomatedRetries {
return errors.New("retries exceeded"), fmt.Sprintf("Login Failed: The maximum amount (%v) of automated retries exceeded.", p.maxAutomatedRetries)
}
logger.Printf("Restarting Oauth2 flow (%v/%v)", retryCount, p.maxAutomatedRetries)
// Handlers typically use this parameter to preserve the login redirect path.
// Since initialization now occurs within the proxy, we must set this parameter
// to retain the originally requested path for restarted requests.
req.Form.Set("rd", redirectPath)
p.doOAuthStart(rw, req, nil, retryCount+1)
return nil, ""
}
// OAuthStart starts the OAuth2 authentication flow // OAuthStart starts the OAuth2 authentication flow
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
// start the flow permitting login URL query parameters to be overridden from the request URL // start the flow permitting login URL query parameters to be overridden from the request URL
p.doOAuthStart(rw, req, req.URL.Query()) p.doOAuthStart(rw, req, req.URL.Query(), 0)
} }
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values) { func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values, retryCount int) {
extraParams := p.provider.Data().LoginURLParams(overrides) extraParams := p.provider.Data().LoginURLParams(overrides)
prepareNoCache(rw) prepareNoCache(rw)
@ -839,7 +866,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove
callbackRedirect := p.getOAuthRedirectURI(req) callbackRedirect := p.getOAuthRedirectURI(req)
loginURL := p.provider.GetLoginURL( loginURL := p.provider.GetLoginURL(
callbackRedirect, callbackRedirect,
encodeState(csrf.HashOAuthState(), appRedirect, p.encodeState), encodeState(csrf.HashOAuthState(), appRedirect, retryCount, p.encodeState),
csrf.HashOIDCNonce(), csrf.HashOIDCNonce(),
extraParams, extraParams,
) )
@ -864,22 +891,33 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return return
} }
errorString := req.Form.Get("error")
if errorString != "" {
logger.Errorf("Error while parsing OAuth2 callback: %s", errorString)
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
// Set the debug message and override the non debug message to be the same for this case
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
return
}
nonce, appRedirect, err := decodeState(req.Form.Get("state"), p.encodeState) nonce, appRedirect, retries, err := decodeState(req.Form.Get("state"), p.encodeState)
if err != nil { if err != nil {
logger.Errorf("Error while parsing OAuth2 state: %v", err) logger.Errorf("Error while parsing OAuth2 state: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return return
} }
errorString := req.Form.Get("error")
if errorString != "" {
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
logger.Error(message)
if slices.Contains(p.idpErrorsToRetry, errorString) {
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
if restartErr != nil {
logger.Errorf("Encountered Error while restarting: %s", restartErr)
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
}
return
}
// Set the debug message and override the non debug message to be the same for this case
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
return
}
// calculate the cookie name // calculate the cookie name
cookieName := cookies.GenerateCookieName(p.CookieOptions, nonce) cookieName := cookies.GenerateCookieName(p.CookieOptions, nonce)
// Try to find the CSRF cookie and decode it // Try to find the CSRF cookie and decode it
@ -888,8 +926,16 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
// There are a lot of issues opened complaining about missing CSRF cookies. // There are a lot of issues opened complaining about missing CSRF cookies.
// Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues. // Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues.
LoggingCSRFCookiesInOAuthCallback(req, cookieName) LoggingCSRFCookiesInOAuthCallback(req, cookieName)
logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce) message := fmt.Sprintf("Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.") logger.Println(req, logger.AuthFailure, message)
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
if restartErr != nil {
logger.Errorf("Encountered Error while restarting: %s", restartErr)
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
}
return return
} }
@ -932,7 +978,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
logger.Errorf("Error with authorization: %v", err) logger.Errorf("Error with authorization: %v", err)
} }
if p.Validator(session.Email) && authorized { if p.Validator(session.Email) && authorized {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2 after %v retries: %s", retries, session)
err := p.SaveSession(rw, req, session) err := p.SaveSession(rw, req, session)
if err != nil { if err != nil {
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
@ -1029,7 +1075,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
// start OAuth flow, but only with the default login URL params - do not // start OAuth flow, but only with the default login URL params - do not
// consider this request's query params as potential overrides, since // consider this request's query params as potential overrides, since
// the user did not explicitly start the login flow // the user did not explicitly start the login flow
p.doOAuthStart(rw, req, nil) p.doOAuthStart(rw, req, nil, 0)
} else { } else {
p.SignInPage(rw, req, http.StatusForbidden) p.SignInPage(rw, req, http.StatusForbidden)
} }
@ -1242,8 +1288,9 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool {
// encodeState builds the OAuth state param out of our nonce and // encodeState builds the OAuth state param out of our nonce and
// original application redirect // original application redirect
func encodeState(nonce string, redirect string, encode bool) string { func encodeState(nonce string, redirect string, retryCount int, encode bool) string {
rawString := fmt.Sprintf("%v:%v", nonce, redirect) redirectEscaped := url.QueryEscape(redirect)
rawString := fmt.Sprintf("%v:%v:%v", nonce, redirectEscaped, retryCount)
if encode { if encode {
return base64.RawURLEncoding.EncodeToString([]byte(rawString)) return base64.RawURLEncoding.EncodeToString([]byte(rawString))
} }
@ -1252,18 +1299,36 @@ func encodeState(nonce string, redirect string, encode bool) string {
// decodeState splits the reflected OAuth state response back into // decodeState splits the reflected OAuth state response back into
// the nonce and original application redirect // the nonce and original application redirect
func decodeState(state string, encode bool) (string, string, error) { func decodeState(state string, encode bool) (string, string, int, error) {
toParse := state toParse := state
if encode { if encode {
decoded, _ := base64.RawURLEncoding.DecodeString(state) decoded, _ := base64.RawURLEncoding.DecodeString(state)
toParse = string(decoded) toParse = string(decoded)
} }
parsedState := strings.SplitN(toParse, ":", 2) parsedState := strings.Split(toParse, ":")
if len(parsedState) != 2 { numStateParams := len(parsedState)
return "", "", errors.New("invalid length")
if numStateParams < 2 || numStateParams > 3 {
return "", "", -1, errors.New(fmt.Sprintf("invalid number of state parameters (%v)", numStateParams))
} }
return parsedState[0], parsedState[1], nil
nonce := parsedState[0]
redirectPath, err := url.QueryUnescape(parsedState[1])
if err != nil {
return "", "", -1, errors.New(fmt.Sprintf("invalid redirectPath url (%v)", parsedState[1]))
}
retries := 0
if numStateParams > 2 {
retries, err = strconv.Atoi(parsedState[2])
if err != nil {
return "", "", -1, errors.New(fmt.Sprintf("invalid retry count (%v)", parsedState[2]))
}
}
return nonce, redirectPath, retries, nil
} }
// addHeadersForProxying adds the appropriate headers the request / response for proxying // addHeadersForProxying adds the appropriate headers the request / response for proxying

View File

@ -413,7 +413,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie
http.MethodGet, http.MethodGet,
fmt.Sprintf( fmt.Sprintf(
"/oauth2/callback?code=callback_code&state=%s", "/oauth2/callback?code=callback_code&state=%s",
encodeState(csrf.HashOAuthState(), "%2F", false), encodeState(csrf.HashOAuthState(), "%2F", 0, false),
), ),
strings.NewReader(""), strings.NewReader(""),
) )
@ -3294,24 +3294,61 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
func TestStateEncodesCorrectly(t *testing.T) { func TestStateEncodesCorrectly(t *testing.T) {
state := "some_state_to_test" state := "some_state_to_test"
nonce := "some_nonce_to_test" nonce := "some_nonce_to_test"
retryCount := 3
encodedResult := encodeState(nonce, state, true) encodedResult := encodeState(nonce, state, retryCount, true)
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult) assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", encodedResult)
notEncodedResult := encodeState(nonce, state, false) notEncodedResult := encodeState(nonce, state, retryCount, false)
assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult) assert.Equal(t, "some_nonce_to_test:some_state_to_test:3", notEncodedResult)
}
func TestStateEncodesWithColonInRedirect(t *testing.T) {
redirect := "test:url"
nonce := "nonce_value"
retryCount := 3
encodedResult := encodeState(nonce, redirect, retryCount, true)
assert.Equal(t, "bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", encodedResult)
notEncodedResult := encodeState(nonce, redirect, retryCount, false)
assert.Equal(t, "nonce_value:test%3Aurl:3", notEncodedResult)
} }
func TestStateDecodesCorrectly(t *testing.T) { func TestStateDecodesCorrectly(t *testing.T) {
nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true) nonce, redirect, retryCount, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", true)
assert.Equal(t, "some_nonce_to_test", nonce) assert.Equal(t, "some_nonce_to_test", nonce)
assert.Equal(t, "some_state_to_test", redirect) assert.Equal(t, "some_state_to_test", redirect)
assert.Equal(t, 3, retryCount)
nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false) nonce2, redirect2, retryCount2, _ := decodeState("some_nonce_to_test:some_state_to_test:3", false)
assert.Equal(t, "some_nonce_to_test", nonce2) assert.Equal(t, "some_nonce_to_test", nonce2)
assert.Equal(t, "some_state_to_test", redirect2) assert.Equal(t, "some_state_to_test", redirect2)
assert.Equal(t, 3, retryCount2)
}
func TestStateDecodesWithColonInRedirect(t *testing.T) {
nonce, redirect, retryCount, _ := decodeState("bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", true)
assert.Equal(t, "nonce_value", nonce)
assert.Equal(t, "test:url", redirect)
assert.Equal(t, 3, retryCount)
nonce2, redirect2, retryCount2, _ := decodeState("nonce_value:test%3Aurl:3", false)
assert.Equal(t, "nonce_value", nonce2)
assert.Equal(t, "test:url", redirect2)
assert.Equal(t, 3, retryCount2)
}
func TestStateDecodesWithoutRetryCount(t *testing.T) {
nonce, redirect, retryCount, _ := decodeState("nonce_value:url", false)
assert.Equal(t, "nonce_value", nonce)
assert.Equal(t, "url", redirect)
assert.Equal(t, 0, retryCount)
} }
func TestAuthOnlyAllowedEmails(t *testing.T) { func TestAuthOnlyAllowedEmails(t *testing.T) {

View File

@ -34,6 +34,9 @@ type Options struct {
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"` WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"`
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"`
HtpasswdUserGroups []string `flag:"htpasswd-user-group" cfg:"htpasswd_user_groups"` HtpasswdUserGroups []string `flag:"htpasswd-user-group" cfg:"htpasswd_user_groups"`
MaxAutomatedRetries int `flag:"max-automated-retries" cfg:"max_automated_retries"`
IdpErrorsToRetry []string `flag:"retry-idp-errors" cfg:"retry_idp_errors"`
RetryCsrfErrors bool `flag:"retry-csrf-errors" cfg:"retry_csrf_errors"`
Cookie Cookie `cfg:",squash"` Cookie Cookie `cfg:",squash"`
Session SessionOptions `cfg:",squash"` Session SessionOptions `cfg:",squash"`
@ -161,6 +164,9 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option") flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
flagSet.Int("max-automated-retries", 0, "Maximum number of automated retries for callback errors")
flagSet.StringSlice("retry-idp-errors", []string{}, "Errors from IdP that should be automatically retried")
flagSet.Bool("retry-csrf-errors", false, "If true retries Csrf errors automatically")
flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(cookieFlagSet())
flagSet.AddFlagSet(loggingFlagSet()) flagSet.AddFlagSet(loggingFlagSet())