diff --git a/oauthproxy.go b/oauthproxy.go index b51b9bea..a595bc3b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -24,9 +24,9 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) @@ -98,7 +98,6 @@ type OAuthProxy struct { SetAuthorization bool PassAuthorization bool PreferEmailToUser bool - ReverseProxy bool skipAuthPreflight bool skipJwtBearerTokens bool templates *template.Template @@ -201,7 +200,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, - ReverseProxy: opts.ReverseProxy, provider: opts.GetProvider(), providerNameOverride: opts.ProviderName, sessionStore: sessionStore, @@ -231,7 +229,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { - chain := alice.New(middleware.NewScope(opts)) + chain := alice.New(middleware.NewScope(opts.ReverseProxy)) if opts.ForceHTTPS { _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) @@ -368,9 +366,9 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { return routes, nil } -// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will +// GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will // redirect clients to once authenticated -func (p *OAuthProxy) GetRedirectURI(host string) string { +func (p *OAuthProxy) GetOAuthRedirectURI(host string) string { // default to the request Host if not set if p.redirectURL.Host != "" { return p.redirectURL.String() @@ -391,7 +389,7 @@ func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessio if code == "" { return nil, providers.ErrMissingCode } - redirectURI := p.GetRedirectURI(host) + redirectURI := p.GetOAuthRedirectURI(host) s, err := p.provider.Redeem(ctx, redirectURI, code) if err != nil { return nil, err @@ -420,7 +418,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) if cookieDomain != "" { - domain := util.GetRequestHost(req) + domain := requestutil.GetRequestHost(req) if h, _, err := net.SplitHostPort(domain); err == nil { domain = h } @@ -509,7 +507,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code } rw.WriteHeader(code) - redirectURL, err := p.GetRedirect(req) + redirectURL, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -568,46 +566,108 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { return "", false } -// GetRedirect reads the query parameter to get the URL to redirect clients to +// GetAppRedirect determines the full URL or URI path to redirect clients to // once authenticated with the OAuthProxy -func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { - err = req.ParseForm() +// Strategy priority (first legal result is used): +// - `rd` querysting parameter +// - `X-Auth-Request-Redirect` header +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) { + err := req.ParseForm() if err != nil { - return + return "", err } - redirect = req.Header.Get("X-Auth-Request-Redirect") - if req.Form.Get("rd") != "" { - redirect = req.Form.Get("rd") - } - // Quirk: On reverse proxies that doesn't have support for - // "X-Auth-Request-Redirect" header or dynamic header/query string - // manipulation (like Traefik v1 and v2), we can try if the header - // X-Forwarded-Host exists or not. - if redirect == "" && isForwardedRequest(req, p.ReverseProxy) { - redirect = p.getRedirectFromForwardHeaders(req) - } - if !p.IsValidRedirect(redirect) { - // Use RequestURI to preserve ?query - redirect = req.URL.RequestURI() - - if strings.HasPrefix(redirect, fmt.Sprintf("%s/", p.ProxyPrefix)) { - redirect = "/" + // These redirect getter functions are strategies ordered by priority + // for figuring out the redirect URL. + type redirectGetter func(req *http.Request) string + for _, rdGetter := range []redirectGetter{ + p.getRdQuerystringRedirect, + p.getXAuthRequestRedirect, + p.getXForwardedHeadersRedirect, + p.getURIRedirect, + } { + if redirect := rdGetter(req); redirect != "" { + return redirect, nil } } - return + return "/", nil } -// getRedirectFromForwardHeaders returns the redirect URL based on X-Forwarded-{Proto,Host,Uri} headers -func (p *OAuthProxy) getRedirectFromForwardHeaders(req *http.Request) string { - uri := util.GetRequestURI(req) +func isForwardedRequest(req *http.Request) bool { + return requestutil.IsProxied(req) && + req.Host != requestutil.GetRequestHost(req) +} - if strings.HasPrefix(uri, fmt.Sprintf("%s/", p.ProxyPrefix)) { +func (p *OAuthProxy) hasProxyPrefix(path string) bool { + return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) +} + +// getRdQuerystringRedirect handles this GetAppRedirect strategy: +// - `rd` querysting parameter +func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { + redirect := req.Form.Get("rd") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXAuthRequestRedirect handles this GetAppRedirect strategy: +// - `X-Auth-Request-Redirect` Header +func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { + redirect := req.Header.Get("X-Auth-Request-Redirect") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXForwardedHeadersRedirect handles these GetAppRedirect strategies: +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { + if !isForwardedRequest(req) { + return "" + } + + uri := requestutil.GetRequestURI(req) + if p.hasProxyPrefix(uri) { uri = "/" } - return fmt.Sprintf("%s://%s%s", util.GetRequestProto(req), util.GetRequestHost(req), uri) + redirect := fmt.Sprintf( + "%s://%s%s", + requestutil.GetRequestProto(req), + requestutil.GetRequestHost(req), + uri, + ) + + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getURIRedirect handles these GetAppRedirect strategies: +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getURIRedirect(req *http.Request) string { + redirect := requestutil.GetRequestURI(req) + if !p.IsValidRedirect(redirect) { + redirect = req.URL.RequestURI() + } + + if p.hasProxyPrefix(redirect) { + return "/" + } + return redirect } // splitHostPort separates host and port. If the port is not valid, it returns @@ -707,12 +767,6 @@ func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { return false } -// isForwardedRequest is used to check if X-Forwarded-Host header exists or not -func isForwardedRequest(req *http.Request, reverseProxy bool) bool { - isForwarded := req.Host != util.GetRequestHost(req) - return isForwarded && reverseProxy -} - // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ "Expires": time.Unix(0, 0).Format(time.RFC1123), @@ -781,7 +835,7 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { // SignIn serves a page prompting users to sign in func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetRedirect(req) + redirect, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -839,7 +893,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { // SignOut sends a response to clear the authentication cookie func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetRedirect(req) + redirect, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -864,13 +918,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { return } p.SetCSRFCookie(rw, req, nonce) - redirect, err := p.GetRedirect(req) + redirect, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - redirectURI := p.GetRedirectURI(util.GetRequestHost(req)) + redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req)) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) } @@ -893,7 +947,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code")) + session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code")) if err != nil { logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") @@ -1024,7 +1078,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R var session *sessionsapi.SessionState getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - session = middleware.GetRequestScope(req).Session + session = middlewareapi.GetRequestScope(req).Session })) getSession.ServeHTTP(rw, req) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 52ffa2b3..8adea1ce 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -19,6 +19,7 @@ import ( "github.com/coreos/go-oidc" "github.com/mbland/hmacauth" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -1750,8 +1751,7 @@ func TestRequestSignature(t *testing.T) { func TestGetRedirect(t *testing.T) { opts := baseTestOptions() - opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com") - opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com:8443") + opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") err := validation.Validate(opts) assert.NoError(t, err) require.NotEmpty(t, opts.ProxyPrefix) @@ -1854,9 +1854,6 @@ func TestGetRedirect(t *testing.T) { url: "https://oauth.example.com/foo/bar", headers: map[string]string{ "X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar", - "X-Forwarded-Proto": "", - "X-Forwarded-Host": "", - "X-Forwarded-Uri": "", }, reverseProxy: true, expectedRedirect: "https://a-service.example.com/foo/bar", @@ -1884,10 +1881,9 @@ func TestGetRedirect(t *testing.T) { name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string", url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz", headers: map[string]string{ - "X-Auth-Request-Redirect": "", - "X-Forwarded-Proto": "https", - "X-Forwarded-Host": "another-service.example.com", - "X-Forwarded-Uri": "/seasons/greetings", + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "another-service.example.com", + "X-Forwarded-Uri": "/seasons/greetings", }, reverseProxy: true, expectedRedirect: "https://a-service.example.com/foo/baz", @@ -1901,8 +1897,10 @@ func TestGetRedirect(t *testing.T) { req.Header.Add(header, value) } } - proxy.ReverseProxy = tt.reverseProxy - redirect, err := proxy.GetRedirect(req) + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: tt.reverseProxy, + }) + redirect, err := proxy.GetAppRedirect(req) assert.NoError(t, err) assert.Equal(t, tt.expectedRedirect, redirect)