Merge bad29b556e into 9168731c7a
This commit is contained in:
commit
d08534f8d3
133
oauthproxy.go
133
oauthproxy.go
|
|
@ -13,6 +13,8 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
|
@ -113,7 +115,10 @@ type OAuthProxy struct {
|
|||
redirectValidator redirect.Validator
|
||||
appDirector redirect.AppDirector
|
||||
|
||||
encodeState bool
|
||||
encodeState bool
|
||||
maxAutomatedRetries int
|
||||
idpErrorsToRetry []string
|
||||
retryCsrfErrors bool
|
||||
}
|
||||
|
||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||
|
|
@ -238,16 +243,19 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||
allowQuerySemicolons: opts.AllowQuerySemicolons,
|
||||
trustedIPs: trustedIPs,
|
||||
|
||||
basicAuthValidator: basicAuthValidator,
|
||||
basicAuthGroups: opts.HtpasswdUserGroups,
|
||||
sessionChain: sessionChain,
|
||||
headersChain: headersChain,
|
||||
preAuthChain: preAuthChain,
|
||||
pageWriter: pageWriter,
|
||||
upstreamProxy: upstreamProxy,
|
||||
redirectValidator: redirectValidator,
|
||||
appDirector: appDirector,
|
||||
encodeState: opts.EncodeState,
|
||||
basicAuthValidator: basicAuthValidator,
|
||||
basicAuthGroups: opts.HtpasswdUserGroups,
|
||||
sessionChain: sessionChain,
|
||||
headersChain: headersChain,
|
||||
preAuthChain: preAuthChain,
|
||||
pageWriter: pageWriter,
|
||||
upstreamProxy: upstreamProxy,
|
||||
redirectValidator: redirectValidator,
|
||||
appDirector: appDirector,
|
||||
encodeState: opts.EncodeState,
|
||||
maxAutomatedRetries: opts.MaxAutomatedRetries,
|
||||
idpErrorsToRetry: opts.IdpErrorsToRetry,
|
||||
retryCsrfErrors: opts.RetryCsrfErrors,
|
||||
}
|
||||
p.buildServeMux(opts.ProxyPrefix)
|
||||
|
||||
|
|
@ -788,13 +796,32 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// OAuthRestart restarts the OAuth2 authentication flow with a specified retryCounter
|
||||
func (p *OAuthProxy) OAuthRestart(rw http.ResponseWriter, req *http.Request, redirectPath string, retryCount int) (error, string) {
|
||||
if !p.redirectValidator.IsValidRedirect(redirectPath) {
|
||||
return errors.New("invalid redirect"), fmt.Sprintf("Login Failed: The given redirect (%v) was not valid. Please try again.", redirectPath)
|
||||
}
|
||||
|
||||
if retryCount >= p.maxAutomatedRetries {
|
||||
return errors.New("retries exceeded"), fmt.Sprintf("Login Failed: The maximum amount (%v) of automated retries exceeded.", p.maxAutomatedRetries)
|
||||
}
|
||||
|
||||
logger.Printf("Restarting Oauth2 flow (%v/%v)", retryCount, p.maxAutomatedRetries)
|
||||
// Handlers typically use this parameter to preserve the login redirect path.
|
||||
// Since initialization now occurs within the proxy, we must set this parameter
|
||||
// to retain the originally requested path for restarted requests.
|
||||
req.Form.Set("rd", redirectPath)
|
||||
p.doOAuthStart(rw, req, nil, retryCount+1)
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// OAuthStart starts the OAuth2 authentication flow
|
||||
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||
// start the flow permitting login URL query parameters to be overridden from the request URL
|
||||
p.doOAuthStart(rw, req, req.URL.Query())
|
||||
p.doOAuthStart(rw, req, req.URL.Query(), 0)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values) {
|
||||
func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, overrides url.Values, retryCount int) {
|
||||
extraParams := p.provider.Data().LoginURLParams(overrides)
|
||||
prepareNoCache(rw)
|
||||
|
||||
|
|
@ -839,7 +866,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove
|
|||
callbackRedirect := p.getOAuthRedirectURI(req)
|
||||
loginURL := p.provider.GetLoginURL(
|
||||
callbackRedirect,
|
||||
encodeState(csrf.HashOAuthState(), appRedirect, p.encodeState),
|
||||
encodeState(csrf.HashOAuthState(), appRedirect, retryCount, p.encodeState),
|
||||
csrf.HashOIDCNonce(),
|
||||
extraParams,
|
||||
)
|
||||
|
|
@ -864,22 +891,33 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
logger.Errorf("Error while parsing OAuth2 callback: %s", errorString)
|
||||
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
|
||||
// Set the debug message and override the non debug message to be the same for this case
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
|
||||
return
|
||||
}
|
||||
|
||||
nonce, appRedirect, err := decodeState(req.Form.Get("state"), p.encodeState)
|
||||
nonce, appRedirect, retries, err := decodeState(req.Form.Get("state"), p.encodeState)
|
||||
if err != nil {
|
||||
logger.Errorf("Error while parsing OAuth2 state: %v", err)
|
||||
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString)
|
||||
logger.Error(message)
|
||||
|
||||
if slices.Contains(p.idpErrorsToRetry, errorString) {
|
||||
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
|
||||
|
||||
if restartErr != nil {
|
||||
logger.Errorf("Encountered Error while restarting: %s", restartErr)
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Set the debug message and override the non debug message to be the same for this case
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, message, message)
|
||||
return
|
||||
}
|
||||
|
||||
// calculate the cookie name
|
||||
cookieName := cookies.GenerateCookieName(p.CookieOptions, nonce)
|
||||
// Try to find the CSRF cookie and decode it
|
||||
|
|
@ -888,8 +926,16 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
// There are a lot of issues opened complaining about missing CSRF cookies.
|
||||
// Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues.
|
||||
LoggingCSRFCookiesInOAuthCallback(req, cookieName)
|
||||
logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.")
|
||||
message := fmt.Sprintf("Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce)
|
||||
logger.Println(req, logger.AuthFailure, message)
|
||||
|
||||
restartErr, restartErrString := p.OAuthRestart(rw, req, appRedirect, retries)
|
||||
|
||||
if restartErr != nil {
|
||||
logger.Errorf("Encountered Error while restarting: %s", restartErr)
|
||||
p.ErrorPage(rw, req, http.StatusForbidden, restartErr.Error(), message, restartErrString)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -932,7 +978,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
logger.Errorf("Error with authorization: %v", err)
|
||||
}
|
||||
if p.Validator(session.Email) && authorized {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2 after %v retries: %s", retries, session)
|
||||
err := p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Error saving session state for %s: %v", remoteAddr, err)
|
||||
|
|
@ -1029,7 +1075,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
|||
// start OAuth flow, but only with the default login URL params - do not
|
||||
// consider this request's query params as potential overrides, since
|
||||
// the user did not explicitly start the login flow
|
||||
p.doOAuthStart(rw, req, nil)
|
||||
p.doOAuthStart(rw, req, nil, 0)
|
||||
} else {
|
||||
p.SignInPage(rw, req, http.StatusForbidden)
|
||||
}
|
||||
|
|
@ -1242,8 +1288,9 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool {
|
|||
|
||||
// encodeState builds the OAuth state param out of our nonce and
|
||||
// original application redirect
|
||||
func encodeState(nonce string, redirect string, encode bool) string {
|
||||
rawString := fmt.Sprintf("%v:%v", nonce, redirect)
|
||||
func encodeState(nonce string, redirect string, retryCount int, encode bool) string {
|
||||
redirectEscaped := url.QueryEscape(redirect)
|
||||
rawString := fmt.Sprintf("%v:%v:%v", nonce, redirectEscaped, retryCount)
|
||||
if encode {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(rawString))
|
||||
}
|
||||
|
|
@ -1252,18 +1299,36 @@ func encodeState(nonce string, redirect string, encode bool) string {
|
|||
|
||||
// decodeState splits the reflected OAuth state response back into
|
||||
// the nonce and original application redirect
|
||||
func decodeState(state string, encode bool) (string, string, error) {
|
||||
func decodeState(state string, encode bool) (string, string, int, error) {
|
||||
toParse := state
|
||||
if encode {
|
||||
decoded, _ := base64.RawURLEncoding.DecodeString(state)
|
||||
toParse = string(decoded)
|
||||
}
|
||||
|
||||
parsedState := strings.SplitN(toParse, ":", 2)
|
||||
if len(parsedState) != 2 {
|
||||
return "", "", errors.New("invalid length")
|
||||
parsedState := strings.Split(toParse, ":")
|
||||
numStateParams := len(parsedState)
|
||||
|
||||
if numStateParams < 2 || numStateParams > 3 {
|
||||
return "", "", -1, errors.New(fmt.Sprintf("invalid number of state parameters (%v)", numStateParams))
|
||||
}
|
||||
return parsedState[0], parsedState[1], nil
|
||||
|
||||
nonce := parsedState[0]
|
||||
|
||||
redirectPath, err := url.QueryUnescape(parsedState[1])
|
||||
if err != nil {
|
||||
return "", "", -1, errors.New(fmt.Sprintf("invalid redirectPath url (%v)", parsedState[1]))
|
||||
}
|
||||
|
||||
retries := 0
|
||||
if numStateParams > 2 {
|
||||
retries, err = strconv.Atoi(parsedState[2])
|
||||
if err != nil {
|
||||
return "", "", -1, errors.New(fmt.Sprintf("invalid retry count (%v)", parsedState[2]))
|
||||
}
|
||||
}
|
||||
|
||||
return nonce, redirectPath, retries, nil
|
||||
}
|
||||
|
||||
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
||||
|
|
|
|||
|
|
@ -413,7 +413,7 @@ func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie
|
|||
http.MethodGet,
|
||||
fmt.Sprintf(
|
||||
"/oauth2/callback?code=callback_code&state=%s",
|
||||
encodeState(csrf.HashOAuthState(), "%2F", false),
|
||||
encodeState(csrf.HashOAuthState(), "%2F", 0, false),
|
||||
),
|
||||
strings.NewReader(""),
|
||||
)
|
||||
|
|
@ -3294,24 +3294,61 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
|
|||
func TestStateEncodesCorrectly(t *testing.T) {
|
||||
state := "some_state_to_test"
|
||||
nonce := "some_nonce_to_test"
|
||||
retryCount := 3
|
||||
|
||||
encodedResult := encodeState(nonce, state, true)
|
||||
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult)
|
||||
encodedResult := encodeState(nonce, state, retryCount, true)
|
||||
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", encodedResult)
|
||||
|
||||
notEncodedResult := encodeState(nonce, state, false)
|
||||
assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult)
|
||||
notEncodedResult := encodeState(nonce, state, retryCount, false)
|
||||
assert.Equal(t, "some_nonce_to_test:some_state_to_test:3", notEncodedResult)
|
||||
}
|
||||
|
||||
func TestStateEncodesWithColonInRedirect(t *testing.T) {
|
||||
redirect := "test:url"
|
||||
nonce := "nonce_value"
|
||||
retryCount := 3
|
||||
|
||||
encodedResult := encodeState(nonce, redirect, retryCount, true)
|
||||
assert.Equal(t, "bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", encodedResult)
|
||||
|
||||
notEncodedResult := encodeState(nonce, redirect, retryCount, false)
|
||||
assert.Equal(t, "nonce_value:test%3Aurl:3", notEncodedResult)
|
||||
}
|
||||
|
||||
func TestStateDecodesCorrectly(t *testing.T) {
|
||||
nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true)
|
||||
nonce, redirect, retryCount, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdDoz", true)
|
||||
|
||||
assert.Equal(t, "some_nonce_to_test", nonce)
|
||||
assert.Equal(t, "some_state_to_test", redirect)
|
||||
assert.Equal(t, 3, retryCount)
|
||||
|
||||
nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false)
|
||||
nonce2, redirect2, retryCount2, _ := decodeState("some_nonce_to_test:some_state_to_test:3", false)
|
||||
|
||||
assert.Equal(t, "some_nonce_to_test", nonce2)
|
||||
assert.Equal(t, "some_state_to_test", redirect2)
|
||||
assert.Equal(t, 3, retryCount2)
|
||||
}
|
||||
|
||||
func TestStateDecodesWithColonInRedirect(t *testing.T) {
|
||||
nonce, redirect, retryCount, _ := decodeState("bm9uY2VfdmFsdWU6dGVzdCUzQXVybDoz", true)
|
||||
|
||||
assert.Equal(t, "nonce_value", nonce)
|
||||
assert.Equal(t, "test:url", redirect)
|
||||
assert.Equal(t, 3, retryCount)
|
||||
|
||||
nonce2, redirect2, retryCount2, _ := decodeState("nonce_value:test%3Aurl:3", false)
|
||||
|
||||
assert.Equal(t, "nonce_value", nonce2)
|
||||
assert.Equal(t, "test:url", redirect2)
|
||||
assert.Equal(t, 3, retryCount2)
|
||||
}
|
||||
|
||||
func TestStateDecodesWithoutRetryCount(t *testing.T) {
|
||||
nonce, redirect, retryCount, _ := decodeState("nonce_value:url", false)
|
||||
|
||||
assert.Equal(t, "nonce_value", nonce)
|
||||
assert.Equal(t, "url", redirect)
|
||||
assert.Equal(t, 0, retryCount)
|
||||
}
|
||||
|
||||
func TestAuthOnlyAllowedEmails(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ type Options struct {
|
|||
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"`
|
||||
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"`
|
||||
HtpasswdUserGroups []string `flag:"htpasswd-user-group" cfg:"htpasswd_user_groups"`
|
||||
MaxAutomatedRetries int `flag:"max-automated-retries" cfg:"max_automated_retries"`
|
||||
IdpErrorsToRetry []string `flag:"retry-idp-errors" cfg:"retry_idp_errors"`
|
||||
RetryCsrfErrors bool `flag:"retry-csrf-errors" cfg:"retry_csrf_errors"`
|
||||
|
||||
Cookie Cookie `cfg:",squash"`
|
||||
Session SessionOptions `cfg:",squash"`
|
||||
|
|
@ -161,6 +164,9 @@ func NewFlagSet() *pflag.FlagSet {
|
|||
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
|
||||
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
|
||||
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
|
||||
flagSet.Int("max-automated-retries", 0, "Maximum number of automated retries for callback errors")
|
||||
flagSet.StringSlice("retry-idp-errors", []string{}, "Errors from IdP that should be automatically retried")
|
||||
flagSet.Bool("retry-csrf-errors", false, "If true retries Csrf errors automatically")
|
||||
|
||||
flagSet.AddFlagSet(cookieFlagSet())
|
||||
flagSet.AddFlagSet(loggingFlagSet())
|
||||
|
|
|
|||
Loading…
Reference in New Issue