This commit is contained in:
TimFiernkranz 2025-09-30 13:26:28 +03:00 committed by GitHub
commit d08534f8d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 149 additions and 41 deletions

View File

@ -13,6 +13,8 @@ import (
"os"
"os/signal"
"regexp"
"slices"
"strconv"
"strings"
"syscall"
"time"
@ -113,7 +115,10 @@ type OAuthProxy struct {
redirectValidator redirect.Validator
appDirector redirect.AppDirector
encodeState bool
encodeState bool
maxAutomatedRetries int
idpErrorsToRetry []string
retryCsrfErrors bool
}
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@ -238,16 +243,19 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
allowQuerySemicolons: opts.AllowQuerySemicolons,
trustedIPs: trustedIPs,
basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups,
sessionChain: sessionChain,
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
upstreamProxy: upstreamProxy,
redirectValidator: redirectValidator,
appDirector: appDirector,
encodeState: opts.EncodeState,
basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups,
sessionChain: sessionChain,
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
upstreamProxy: upstreamProxy,
redirectValidator: redirectValidator,
appDirector: appDirector,
encodeState: opts.EncodeState,
maxAutomatedRetries: opts.MaxAutomatedRetries,
idpErrorsToRetry: opts.IdpErrorsToRetry,
retryCsrfErrors: opts.RetryCsrfErrors,
}
p.buildServeMux(opts.ProxyPrefix)
@ -788,13 +796,32 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
}
}
// OAuthRestart restarts the OAuth2 authentication flow with a specified retryCounter
func (p *OAuthProxy) OAuthRestart(rw http.ResponseWriter, req *http.Request, redirectPath string, retryCount int) (error, string) {
if !p.redirectValidator.IsValidRedirect(redirectPath) {
return errors.New("invalid redirect"), fmt.Sprintf("Login Failed: The given redirect (%v) was not valid. Please try again.", redirectPath)
}
if retryCount >= p.maxAutomatedRetries {
return errors.New("retries exceeded"), fmt.Sprintf("Login Failed: The maximum amount (%v) of automated retries exceeded.", p.maxAutomatedRetries)
}
logger.Printf("Restarting Oauth2 flow (%v/%v)", retryCount, p.maxAutomatedRetries)
// Handlers typically use this parameter to preserve the login redirect path.
// Since initialization now occurs within the proxy, we must set this parameter
// to retain the originally requested path for restarted requests.
req.Form.Set("rd", redirectPath)
p.doOAuthStart(rw, req, nil, retryCount+1)
return nil, ""
}
// OAuthStart starts the OAuth2 authentication flow
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
// start the flow permitting login URL query parameters to be overridden from the request URL
p.doOAuthStart(rw, req, req.URL.Query())
p.doOAuthStart(rw, req, req.URL.Query(), 0)
}
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values) {
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values, retryCount int) {
extraParams := p.provider.Data().LoginURLParams(overrides)
prepareNoCache(rw)
@ -839,7 +866,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove
callbackRedirect := p.getOAuthRedirectURI(req)
loginURL := p.provider.GetLoginURL(
callbackRedirect,
encodeState(csrf.HashOAuthState(), appRedirect, p.encodeState),
encodeState(csrf.HashOAuthState(), appRedirect, retryCount, p.encodeState),
csrf.HashOIDCNonce(),
extraParams,
)
@ -864,22 +891,33 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
errorString := req.Form.Get("error")
if errorString != "" {
logger.Errorf("Error while parsing OAuth2 callback: %s", errorString)
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
// Set the debug message and override the non debug message to be the same for this case
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
return
}
nonce, appRedirect, err := decodeState(req.Form.Get("state"), p.encodeState)
nonce, appRedirect, retries, err := decodeState(req.Form.Get("state"), p.encodeState)
if err != nil {
logger.Errorf("Error while parsing OAuth2 state: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
errorString := req.Form.Get("error")
if errorString != "" {
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
logger.Error(message)
if slices.Contains(p.idpErrorsToRetry, errorString) {
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
if restartErr != nil {
logger.Errorf("Encountered Error while restarting: %s", restartErr)
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
}
return
}
// Set the debug message and override the non debug message to be the same for this case
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
return
}
// calculate the cookie name
cookieName := cookies.GenerateCookieName(p.CookieOptions, nonce)
// Try to find the CSRF cookie and decode it
@ -888,8 +926,16 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
// There are a lot of issues opened complaining about missing CSRF cookies.
// Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues.
LoggingCSRFCookiesInOAuthCallback(req, cookieName)
logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.")
message := fmt.Sprintf("Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
logger.Println(req, logger.AuthFailure, message)
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
if restartErr != nil {
logger.Errorf("Encountered Error while restarting: %s", restartErr)
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
}
return
}
@ -932,7 +978,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
logger.Errorf("Error with authorization: %v", err)
}
if p.Validator(session.Email) && authorized {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2 after %v retries: %s", retries, session)
err := p.SaveSession(rw, req, session)
if err != nil {
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
@ -1029,7 +1075,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
// start OAuth flow, but only with the default login URL params - do not
// consider this request's query params as potential overrides, since
// the user did not explicitly start the login flow
p.doOAuthStart(rw, req, nil)
p.doOAuthStart(rw, req, nil, 0)
} else {
p.SignInPage(rw, req, http.StatusForbidden)
}
@ -1242,8 +1288,9 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool {
// encodeState builds the OAuth state param out of our nonce and
// original application redirect
func encodeState(nonce string, redirect string, encode bool) string {
rawString := fmt.Sprintf("%v:%v", nonce, redirect)
func encodeState(nonce string, redirect string, retryCount int, encode bool) string {
redirectEscaped := url.QueryEscape(redirect)
rawString := fmt.Sprintf("%v:%v:%v", nonce, redirectEscaped, retryCount)
if encode {
return base64.RawURLEncoding.EncodeToString([]byte(rawString))
}
@ -1252,18 +1299,36 @@ func encodeState(nonce string, redirect string, encode bool) string {
// decodeState splits the reflected OAuth state response back into
// the nonce and original application redirect
func decodeState(state string, encode bool) (string, string, error) {
func decodeState(state string, encode bool) (string, string, int, error) {
toParse := state
if encode {
decoded, _ := base64.RawURLEncoding.DecodeString(state)
toParse = string(decoded)
}
parsedState := strings.SplitN(toParse, ":", 2)
if len(parsedState) != 2 {
return "", "", errors.New("invalid length")
parsedState := strings.Split(toParse, ":")
numStateParams := len(parsedState)
if numStateParams < 2 || numStateParams > 3 {
return "", "", -1, errors.New(fmt.Sprintf("invalid number of state parameters (%v)", numStateParams))
}
return parsedState[0], parsedState[1], nil
nonce := parsedState[0]
redirectPath, err := url.QueryUnescape(parsedState[1])
if err != nil {
return "", "", -1, errors.New(fmt.Sprintf("invalid redirectPath url (%v)", parsedState[1]))
}
retries := 0
if numStateParams > 2 {
retries, err = strconv.Atoi(parsedState[2])
if err != nil {
return "", "", -1, errors.New(fmt.Sprintf("invalid retry count (%v)", parsedState[2]))
}
}
return nonce, redirectPath, retries, nil
}
// addHeadersForProxying adds the appropriate headers the request / response for proxying

View File

@ -413,7 +413,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie
http.MethodGet,
fmt.Sprintf(
"/oauth2/callback?code=callback_code&state=%s",
encodeState(csrf.HashOAuthState(), "%2F", false),
encodeState(csrf.HashOAuthState(), "%2F", 0, false),
),
strings.NewReader(""),
)
@ -3294,24 +3294,61 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
func TestStateEncodesCorrectly(t *testing.T) {
state := "some_state_to_test"
nonce := "some_nonce_to_test"
retryCount := 3
encodedResult := encodeState(nonce, state, true)
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult)
encodedResult := encodeState(nonce, state, retryCount, true)
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", encodedResult)
notEncodedResult := encodeState(nonce, state, false)
assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult)
notEncodedResult := encodeState(nonce, state, retryCount, false)
assert.Equal(t, "some_nonce_to_test:some_state_to_test:3", notEncodedResult)
}
func TestStateEncodesWithColonInRedirect(t *testing.T) {
redirect := "test:url"
nonce := "nonce_value"
retryCount := 3
encodedResult := encodeState(nonce, redirect, retryCount, true)
assert.Equal(t, "bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", encodedResult)
notEncodedResult := encodeState(nonce, redirect, retryCount, false)
assert.Equal(t, "nonce_value:test%3Aurl:3", notEncodedResult)
}
func TestStateDecodesCorrectly(t *testing.T) {
nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true)
nonce, redirect, retryCount, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", true)
assert.Equal(t, "some_nonce_to_test", nonce)
assert.Equal(t, "some_state_to_test", redirect)
assert.Equal(t, 3, retryCount)
nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false)
nonce2, redirect2, retryCount2, _ := decodeState("some_nonce_to_test:some_state_to_test:3", false)
assert.Equal(t, "some_nonce_to_test", nonce2)
assert.Equal(t, "some_state_to_test", redirect2)
assert.Equal(t, 3, retryCount2)
}
func TestStateDecodesWithColonInRedirect(t *testing.T) {
nonce, redirect, retryCount, _ := decodeState("bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", true)
assert.Equal(t, "nonce_value", nonce)
assert.Equal(t, "test:url", redirect)
assert.Equal(t, 3, retryCount)
nonce2, redirect2, retryCount2, _ := decodeState("nonce_value:test%3Aurl:3", false)
assert.Equal(t, "nonce_value", nonce2)
assert.Equal(t, "test:url", redirect2)
assert.Equal(t, 3, retryCount2)
}
func TestStateDecodesWithoutRetryCount(t *testing.T) {
nonce, redirect, retryCount, _ := decodeState("nonce_value:url", false)
assert.Equal(t, "nonce_value", nonce)
assert.Equal(t, "url", redirect)
assert.Equal(t, 0, retryCount)
}
func TestAuthOnlyAllowedEmails(t *testing.T) {

View File

@ -34,6 +34,9 @@ type Options struct {
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"`
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"`
HtpasswdUserGroups []string `flag:"htpasswd-user-group" cfg:"htpasswd_user_groups"`
MaxAutomatedRetries int `flag:"max-automated-retries" cfg:"max_automated_retries"`
IdpErrorsToRetry []string `flag:"retry-idp-errors" cfg:"retry_idp_errors"`
RetryCsrfErrors bool `flag:"retry-csrf-errors" cfg:"retry_csrf_errors"`
Cookie Cookie `cfg:",squash"`
Session SessionOptions `cfg:",squash"`
@ -161,6 +164,9 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
flagSet.Int("max-automated-retries", 0, "Maximum number of automated retries for callback errors")
flagSet.StringSlice("retry-idp-errors", []string{}, "Errors from IdP that should be automatically retried")
flagSet.Bool("retry-csrf-errors", false, "If true retries Csrf errors automatically")
flagSet.AddFlagSet(cookieFlagSet())
flagSet.AddFlagSet(loggingFlagSet())