diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e38f0cd..5d0d8021 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ## Important Notes +- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Redirect URL generation will attempt secondary strategies + in the priority chain if any fail the `IsValidRedirect` security check. Previously any failures fell back to `/`. - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint instead of `--validate-url`. `--validate-url` will still work for backwards compatibility. - [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) To use X-Forwarded-{Proto,Host,Uri} on redirect detection, `--reverse-proxy` must be `true`. @@ -36,6 +38,11 @@ ## Breaking Changes +- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) `--reverse-proxy` must be true to trust `X-Forwarded-*` headers as canonical. + These are used throughout the application in redirect URLs, cookie domains and host logging logic. These are the headers: + - `X-Forwarded-Proto` instead of `req.URL.Scheme` + - `X-Forwarded-Host` instead of `req.Host` + - `X-Forwarded-Uri` instead of `req.URL.RequestURI()` - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) In config files & envvar configs, `keycloak_group` is now the plural `keycloak_groups`. Flag configs are still `--keycloak-group` but it can be passed multiple times. - [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". @@ -60,6 +67,7 @@ ## Changes since v6.1.1 - [#995](https://github.com/oauth2-proxy/oauth2-proxy/pull/995) Add Security Policy (@JoelSpeed) +- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Require `--reverse-proxy` true to trust `X-Forwareded-*` type headers (@NickMeves) - [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered) - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves) - [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini) diff --git a/oauthproxy.go b/oauthproxy.go index cfba6934..36c58c46 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -24,16 +24,14 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) const ( - httpScheme = "http" - httpsScheme = "https" - + schemeHTTPS = "https" applicationJSON = "application/json" ) @@ -98,7 +96,6 @@ type OAuthProxy struct { SetAuthorization bool PassAuthorization bool PreferEmailToUser bool - ReverseProxy bool skipAuthPreflight bool skipJwtBearerTokens bool templates *template.Template @@ -201,7 +198,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, - ReverseProxy: opts.ReverseProxy, provider: opts.GetProvider(), providerNameOverride: opts.ProviderName, sessionStore: sessionStore, @@ -231,7 +227,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { - chain := alice.New(middleware.NewScope()) + chain := alice.New(middleware.NewScope(opts.ReverseProxy)) if opts.ForceHTTPS { _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) @@ -368,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { return routes, nil } -// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will -// redirect clients to once authenticated -func (p *OAuthProxy) GetRedirectURI(host string) string { - // default to the request Host if not set - if p.redirectURL.Host != "" { - return p.redirectURL.String() - } - u := *p.redirectURL - if u.Scheme == "" { - if p.CookieSecure { - u.Scheme = httpsScheme - } else { - u.Scheme = httpScheme - } - } - u.Host = host - return u.String() -} - -func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { - if code == "" { - return nil, providers.ErrMissingCode - } - redirectURI := p.GetRedirectURI(host) - s, err := p.provider.Redeem(ctx, redirectURI, code) - if err != nil { - return nil, err - } - return s, nil -} - -func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { - var err error - if s.Email == "" { - s.Email, err = p.provider.GetEmailAddress(ctx, s) - if err != nil && !errors.Is(err, providers.ErrNotImplemented) { - return err - } - } - - return p.provider.EnrichSession(ctx, s) -} - // MakeCSRFCookie creates a cookie for CSRF func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) @@ -420,7 +373,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) if cookieDomain != "" { - domain := util.GetRequestHost(req) + domain := requestutil.GetRequestHost(req) if h, _, err := net.SplitHostPort(domain); err == nil { domain = h } @@ -468,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s return p.sessionStore.Save(rw, req, s) } +// IsValidRedirect checks whether the redirect URL is whitelisted +func (p *OAuthProxy) IsValidRedirect(redirect string) bool { + switch { + case redirect == "": + // The user didn't specify a redirect, should fallback to `/` + return false + case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): + return true + case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): + redirectURL, err := url.Parse(redirect) + if err != nil { + logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) + return false + } + redirectHostname := redirectURL.Hostname() + + for _, domain := range p.whitelistDomains { + domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) + if domainHostname == "" { + continue + } + + if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { + // the domain names match, now validate the ports + // if the whitelisted domain's port is '*', allow all ports + // if the whitelisted domain contains a specific port, only allow that port + // if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https + redirectPort := redirectURL.Port() + if (domainPort == "*") || + (domainPort == redirectPort) || + (domainPort == "" && redirectPort == "") { + return true + } + } + } + + logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) + return false + default: + logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) + return false + } +} + +func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) +} + +func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { + if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { + prepareNoCache(rw) + } + + switch path := req.URL.Path; { + case path == p.RobotsPath: + p.RobotsTxt(rw) + case p.IsAllowedRequest(req): + p.SkipAuthProxy(rw, req) + case path == p.SignInPath: + p.SignIn(rw, req) + case path == p.SignOutPath: + p.SignOut(rw, req) + case path == p.OAuthStartPath: + p.OAuthStart(rw, req) + case path == p.OAuthCallbackPath: + p.OAuthCallback(rw, req) + case path == p.AuthOnlyPath: + p.AuthOnly(rw, req) + case path == p.UserInfoPath: + p.UserInfo(rw, req) + default: + p.Proxy(rw, req) + } +} + // RobotsTxt disallows scraping pages from the OAuthProxy func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { _, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") @@ -498,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m } } +// IsAllowedRequest is used to check if auth should be skipped for this request +func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { + isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" + return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) +} + +// IsAllowedRoute is used to check if the request method & path is allowed without auth +func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { + for _, route := range p.allowedRoutes { + if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { + return true + } + } + return false +} + +// isTrustedIP is used to check if a request comes from a trusted client IP address. +func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { + if p.trustedIPs == nil { + return false + } + + remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) + if err != nil { + logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) + // Possibly spoofed X-Real-IP header + return false + } + + if remoteAddr == nil { + return false + } + + return p.trustedIPs.Has(remoteAddr) +} + // SignInPage writes the sing in template to the response func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { prepareNoCache(rw) @@ -509,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code } rw.WriteHeader(code) - redirectURL, err := p.GetRedirect(req) + redirectURL, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -568,220 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { return "", false } -// GetRedirect reads the query parameter to get the URL to redirect clients to -// once authenticated with the OAuthProxy -func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { - err = req.ParseForm() - if err != nil { - return - } - - redirect = req.Header.Get("X-Auth-Request-Redirect") - if req.Form.Get("rd") != "" { - redirect = req.Form.Get("rd") - } - // Quirk: On reverse proxies that doesn't have support for - // "X-Auth-Request-Redirect" header or dynamic header/query string - // manipulation (like Traefik v1 and v2), we can try if the header - // X-Forwarded-Host exists or not. - if redirect == "" && isForwardedRequest(req, p.ReverseProxy) { - redirect = p.getRedirectFromForwardHeaders(req) - } - if !p.IsValidRedirect(redirect) { - // Use RequestURI to preserve ?query - redirect = req.URL.RequestURI() - - if strings.HasPrefix(redirect, fmt.Sprintf("%s/", p.ProxyPrefix)) { - redirect = "/" - } - } - - return -} - -// getRedirectFromForwardHeaders returns the redirect URL based on X-Forwarded-{Proto,Host,Uri} headers -func (p *OAuthProxy) getRedirectFromForwardHeaders(req *http.Request) string { - uri := util.GetRequestURI(req) - - if strings.HasPrefix(uri, fmt.Sprintf("%s/", p.ProxyPrefix)) { - uri = "/" - } - - return fmt.Sprintf("%s://%s%s", util.GetRequestProto(req), util.GetRequestHost(req), uri) -} - -// splitHostPort separates host and port. If the port is not valid, it returns -// the entire input as host, and it doesn't check the validity of the host. -// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. -// *** taken from net/url, modified validOptionalPort() to accept ":*" -func splitHostPort(hostport string) (host, port string) { - host = hostport - - colon := strings.LastIndexByte(host, ':') - if colon != -1 && validOptionalPort(host[colon:]) { - host, port = host[:colon], host[colon+1:] - } - - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - host = host[1 : len(host)-1] - } - - return -} - -// validOptionalPort reports whether port is either an empty string -// or matches /^:\d*$/ -// *** taken from net/url, modified to accept ":*" -func validOptionalPort(port string) bool { - if port == "" || port == ":*" { - return true - } - if port[0] != ':' { - return false - } - for _, b := range port[1:] { - if b < '0' || b > '9' { - return false - } - } - return true -} - -// IsValidRedirect checks whether the redirect URL is whitelisted -func (p *OAuthProxy) IsValidRedirect(redirect string) bool { - switch { - case redirect == "": - // The user didn't specify a redirect, should fallback to `/` - return false - case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): - return true - case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): - redirectURL, err := url.Parse(redirect) - if err != nil { - logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) - return false - } - redirectHostname := redirectURL.Hostname() - - for _, domain := range p.whitelistDomains { - domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) - if domainHostname == "" { - continue - } - - if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { - // the domain names match, now validate the ports - // if the whitelisted domain's port is '*', allow all ports - // if the whitelisted domain contains a specific port, only allow that port - // if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https - redirectPort := redirectURL.Port() - if (domainPort == "*") || - (domainPort == redirectPort) || - (domainPort == "" && redirectPort == "") { - return true - } - } - } - - logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) - return false - default: - logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) - return false - } -} - -// IsAllowedRequest is used to check if auth should be skipped for this request -func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { - isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" - return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req) -} - -// IsAllowedRoute is used to check if the request method & path is allowed without auth -func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { - for _, route := range p.allowedRoutes { - if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { - return true - } - } - return false -} - -// isForwardedRequest is used to check if X-Forwarded-Host header exists or not -func isForwardedRequest(req *http.Request, reverseProxy bool) bool { - isForwarded := req.Host != util.GetRequestHost(req) - return isForwarded && reverseProxy -} - -// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en -var noCacheHeaders = map[string]string{ - "Expires": time.Unix(0, 0).Format(time.RFC1123), - "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", - "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ -} - -// prepareNoCache prepares headers for preventing browser caching. -func prepareNoCache(w http.ResponseWriter) { - // Set NoCache headers - for k, v := range noCacheHeaders { - w.Header().Set(k, v) - } -} - -// IsTrustedIP is used to check if a request comes from a trusted client IP address. -func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool { - if p.trustedIPs == nil { - return false - } - - remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) - if err != nil { - logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) - // Possibly spoofed X-Real-IP header - return false - } - - if remoteAddr == nil { - return false - } - - return p.trustedIPs.Has(remoteAddr) -} - -func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) -} - -func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { - if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { - prepareNoCache(rw) - } - - switch path := req.URL.Path; { - case path == p.RobotsPath: - p.RobotsTxt(rw) - case p.IsAllowedRequest(req): - p.SkipAuthProxy(rw, req) - case path == p.SignInPath: - p.SignIn(rw, req) - case path == p.SignOutPath: - p.SignOut(rw, req) - case path == p.OAuthStartPath: - p.OAuthStart(rw, req) - case path == p.OAuthCallbackPath: - p.OAuthCallback(rw, req) - case path == p.AuthOnlyPath: - p.AuthOnly(rw, req) - case path == p.UserInfoPath: - p.UserInfo(rw, req) - default: - p.Proxy(rw, req) - } -} - // SignIn serves a page prompting users to sign in func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -839,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { // SignOut sends a response to clear the authentication cookie func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -864,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { return } p.SetCSRFCookie(rw, req, nonce) - redirect, err := p.GetRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - redirectURI := p.GetRedirectURI(util.GetRequestHost(req)) + redirectURI := p.getOAuthRedirectURI(req) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) } @@ -893,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code")) + session, err := p.redeemCode(req) if err != nil { logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") @@ -952,6 +805,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } } +func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) { + code := req.Form.Get("code") + if code == "" { + return nil, providers.ErrMissingCode + } + + redirectURI := p.getOAuthRedirectURI(req) + s, err := p.provider.Redeem(req.Context(), redirectURI, code) + if err != nil { + return nil, err + } + return s, nil +} + +func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { + var err error + if s.Email == "" { + s.Email, err = p.provider.GetEmailAddress(ctx, s) + if err != nil && !errors.Is(err, providers.ErrNotImplemented) { + return err + } + } + + return p.provider.EnrichSession(ctx, s) +} + // AuthOnly checks whether the user is currently logged in (both authentication // and optional authorization). func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { @@ -969,7 +848,7 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { } // we are authenticated - p.addHeadersForProxying(rw, req, session) + p.addHeadersForProxying(rw, session) p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusAccepted) })).ServeHTTP(rw, req) @@ -987,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { switch err { case nil: // we are authenticated - p.addHeadersForProxying(rw, req, session) + p.addHeadersForProxying(rw, session) p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) case ErrNeedsLogin: // we need to send the user to a login screen if isAjax(req) { // no point redirecting an AJAX request - p.ErrorJSON(rw, http.StatusUnauthorized) + p.errorJSON(rw, http.StatusUnauthorized) return } @@ -1012,7 +891,195 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", "Internal Error") } +} +// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en +var noCacheHeaders = map[string]string{ + "Expires": time.Unix(0, 0).Format(time.RFC1123), + "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", + "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ +} + +// prepareNoCache prepares headers for preventing browser caching. +func prepareNoCache(w http.ResponseWriter) { + // Set NoCache headers + for k, v := range noCacheHeaders { + w.Header().Set(k, v) + } +} + +// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will +// redirect clients to once authenticated. +// This is usually the OAuthProxy callback URL. +func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string { + // if `p.redirectURL` already has a host, return it + if p.redirectURL.Host != "" { + return p.redirectURL.String() + } + + // Otherwise figure out the scheme + host from the request + rd := *p.redirectURL + rd.Host = requestutil.GetRequestHost(req) + rd.Scheme = requestutil.GetRequestProto(req) + + // If CookieSecure is true, return `https` no matter what + // Not all reverse proxies set X-Forwarded-Proto + if p.CookieSecure { + rd.Scheme = schemeHTTPS + } + return rd.String() +} + +// getAppRedirect determines the full URL or URI path to redirect clients to +// once authenticated with the OAuthProxy +// Strategy priority (first legal result is used): +// - `rd` querysting parameter +// - `X-Auth-Request-Redirect` header +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) { + err := req.ParseForm() + if err != nil { + return "", err + } + + // These redirect getter functions are strategies ordered by priority + // for figuring out the redirect URL. + type redirectGetter func(req *http.Request) string + for _, rdGetter := range []redirectGetter{ + p.getRdQuerystringRedirect, + p.getXAuthRequestRedirect, + p.getXForwardedHeadersRedirect, + p.getURIRedirect, + } { + redirect := rdGetter(req) + // Call `p.IsValidRedirect` again here a final time to be safe + if redirect != "" && p.IsValidRedirect(redirect) { + return redirect, nil + } + } + + return "/", nil +} + +func isForwardedRequest(req *http.Request) bool { + return requestutil.IsProxied(req) && + req.Host != requestutil.GetRequestHost(req) +} + +func (p *OAuthProxy) hasProxyPrefix(path string) bool { + return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) +} + +func (p *OAuthProxy) validateRedirect(redirect string, errorFormat string) string { + if p.IsValidRedirect(redirect) { + return redirect + } + if redirect != "" { + logger.Errorf(errorFormat, redirect) + } + return "" +} + +// getRdQuerystringRedirect handles this getAppRedirect strategy: +// - `rd` querysting parameter +func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { + return p.validateRedirect( + req.Form.Get("rd"), + "Invalid redirect provided in rd querystring parameter: %s", + ) +} + +// getXAuthRequestRedirect handles this getAppRedirect strategy: +// - `X-Auth-Request-Redirect` Header +func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { + return p.validateRedirect( + req.Header.Get("X-Auth-Request-Redirect"), + "Invalid redirect provided in X-Auth-Request-Redirect header: %s", + ) +} + +// getXForwardedHeadersRedirect handles these getAppRedirect strategies: +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { + if !isForwardedRequest(req) { + return "" + } + + uri := requestutil.GetRequestURI(req) + if p.hasProxyPrefix(uri) { + uri = "/" + } + + redirect := fmt.Sprintf( + "%s://%s%s", + requestutil.GetRequestProto(req), + requestutil.GetRequestHost(req), + uri, + ) + + return p.validateRedirect(redirect, + "Invalid redirect generated from X-Forwarded-* headers: %s") +} + +// getURIRedirect handles these getAppRedirect strategies: +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getURIRedirect(req *http.Request) string { + redirect := p.validateRedirect( + requestutil.GetRequestURI(req), + "Invalid redirect generated from X-Forwarded-Uri header: %s", + ) + if redirect == "" { + redirect = req.URL.RequestURI() + } + + if p.hasProxyPrefix(redirect) { + return "/" + } + return redirect +} + +// splitHostPort separates host and port. If the port is not valid, it returns +// the entire input as host, and it doesn't check the validity of the host. +// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. +// *** taken from net/url, modified validOptionalPort() to accept ":*" +func splitHostPort(hostport string) (host, port string) { + host = hostport + + colon := strings.LastIndexByte(host, ':') + if colon != -1 && validOptionalPort(host[colon:]) { + host, port = host[:colon], host[colon+1:] + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return +} + +// validOptionalPort reports whether port is either an empty string +// or matches /^:\d*$/ +// *** taken from net/url, modified to accept ":*" +func validOptionalPort(port string) bool { + if port == "" || port == ":*" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true } // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so @@ -1024,7 +1091,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R var session *sessionsapi.SessionState getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - session = middleware.GetRequestScope(req).Session + session = middlewareapi.GetRequestScope(req).Session })) getSession.ServeHTTP(rw, req) @@ -1099,7 +1166,7 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} { } // addHeadersForProxying adds the appropriate headers the request / response for proxying -func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) { +func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) { if session.Email == "" { rw.Header().Set("GAP-Auth", session.User) } else { @@ -1127,8 +1194,8 @@ func isAjax(req *http.Request) bool { return false } -// ErrorJSON returns the error code with an application/json mime type -func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { +// errorJSON returns the error code with an application/json mime type +func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 52ffa2b3..3366ef5f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -19,6 +19,7 @@ import ( "github.com/coreos/go-oidc" "github.com/mbland/hmacauth" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -414,8 +415,9 @@ func Test_redeemCode(t *testing.T) { t.Fatal(err) } - _, err = proxy.redeemCode(context.Background(), "www.example.com", "") - assert.Error(t, err) + req := httptest.NewRequest(http.MethodGet, "/", nil) + _, err = proxy.redeemCode(req) + assert.Equal(t, providers.ErrMissingCode, err) } func Test_enrichSession(t *testing.T) { @@ -1748,10 +1750,9 @@ func TestRequestSignature(t *testing.T) { } } -func TestGetRedirect(t *testing.T) { +func Test_getAppRedirect(t *testing.T) { opts := baseTestOptions() - opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com") - opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com:8443") + opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") err := validation.Validate(opts) assert.NoError(t, err) require.NotEmpty(t, opts.ProxyPrefix) @@ -1854,9 +1855,6 @@ func TestGetRedirect(t *testing.T) { url: "https://oauth.example.com/foo/bar", headers: map[string]string{ "X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar", - "X-Forwarded-Proto": "", - "X-Forwarded-Host": "", - "X-Forwarded-Uri": "", }, reverseProxy: true, expectedRedirect: "https://a-service.example.com/foo/bar", @@ -1884,10 +1882,9 @@ func TestGetRedirect(t *testing.T) { name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string", url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz", headers: map[string]string{ - "X-Auth-Request-Redirect": "", - "X-Forwarded-Proto": "https", - "X-Forwarded-Host": "another-service.example.com", - "X-Forwarded-Uri": "/seasons/greetings", + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "another-service.example.com", + "X-Forwarded-Uri": "/seasons/greetings", }, reverseProxy: true, expectedRedirect: "https://a-service.example.com/foo/baz", @@ -1901,8 +1898,10 @@ func TestGetRedirect(t *testing.T) { req.Header.Add(header, value) } } - proxy.ReverseProxy = tt.reverseProxy - redirect, err := proxy.GetRedirect(req) + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: tt.reverseProxy, + }) + redirect, err := proxy.getAppRedirect(req) assert.NoError(t, err) assert.Equal(t, tt.expectedRedirect, redirect) diff --git a/pkg/apis/middleware/middleware_suite_test.go b/pkg/apis/middleware/middleware_suite_test.go new file mode 100644 index 00000000..f2f48cfd --- /dev/null +++ b/pkg/apis/middleware/middleware_suite_test.go @@ -0,0 +1,19 @@ +package middleware_test + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// TestMiddlewareSuite and related tests are in a *_test package +// to prevent circular imports with the `logger` package which uses +// this functionality +func TestMiddlewareSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Middleware API") +} diff --git a/pkg/apis/middleware/scope.go b/pkg/apis/middleware/scope.go index 37f6f336..c54a33d1 100644 --- a/pkg/apis/middleware/scope.go +++ b/pkg/apis/middleware/scope.go @@ -1,13 +1,26 @@ package middleware import ( + "context" + "net/http" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" ) +type scopeKey string + +// RequestScopeKey uses a typed string to reduce likelihood of clashing +// with other context keys +const RequestScopeKey scopeKey = "request-scope" + // RequestScope contains information regarding the request that is being made. // The RequestScope is used to pass information between different middlewares // within the chain. type RequestScope struct { + // ReverseProxy tracks whether OAuth2-Proxy is operating in reverse proxy + // mode and if request `X-Forwarded-*` headers should be trusted + ReverseProxy bool + // Session details the authenticated users information (if it exists). Session *sessions.SessionState @@ -22,3 +35,19 @@ type RequestScope struct { // it was loaded or not. SessionRevalidated bool } + +// GetRequestScope returns the current request scope from the given request +func GetRequestScope(req *http.Request) *RequestScope { + scope := req.Context().Value(RequestScopeKey) + if scope == nil { + return nil + } + + return scope.(*RequestScope) +} + +// AddRequestScope adds a RequestScope to a request +func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request { + ctx := context.WithValue(req.Context(), RequestScopeKey, scope) + return req.WithContext(ctx) +} diff --git a/pkg/apis/middleware/scope_test.go b/pkg/apis/middleware/scope_test.go new file mode 100644 index 00000000..355365bf --- /dev/null +++ b/pkg/apis/middleware/scope_test.go @@ -0,0 +1,56 @@ +package middleware_test + +import ( + "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Scope Suite", func() { + Context("GetRequestScope", func() { + var request *http.Request + + BeforeEach(func() { + var err error + request, err = http.NewRequest("", "http://127.0.0.1/", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("with a scope", func() { + var scope *middleware.RequestScope + + BeforeEach(func() { + scope = &middleware.RequestScope{} + request = middleware.AddRequestScope(request, scope) + }) + + It("returns the scope", func() { + s := middleware.GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + }) + + Context("if the scope is then modified", func() { + BeforeEach(func() { + Expect(scope.SaveSession).To(BeFalse()) + scope.SaveSession = true + }) + + It("returns the updated session", func() { + s := middleware.GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + Expect(s.SaveSession).To(BeTrue()) + }) + }) + }) + + Context("without a scope", func() { + It("returns nil", func() { + Expect(middleware.GetRequestScope(request)).To(BeNil()) + }) + }) + }) +}) diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go index 9b6dc03d..c590de38 100644 --- a/pkg/cookies/cookies.go +++ b/pkg/cookies/cookies.go @@ -9,14 +9,14 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) // MakeCookie constructs a cookie from the given parameters, // discovering the domain from the request if not specified. func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { if domain != "" { - host := util.GetRequestHost(req) + host := requestutil.GetRequestHost(req) if h, _, err := net.SplitHostPort(host); err == nil { host = h } @@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // If nothing matches, create the cookie with the shortest domain defaultDomain := "" if len(cookieOpts.Domains) > 0 { - logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) + logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] } return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) @@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // GetCookieDomain returns the correct cookie domain given a list of domains // by checking the X-Fowarded-Host and host header of an an http request func GetCookieDomain(req *http.Request, cookieDomains []string) string { - host := util.GetRequestHost(req) + host := requestutil.GetRequestHost(req) for _, domain := range cookieDomains { if strings.HasSuffix(host, domain) { return domain diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 23696765..86ad720e 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -12,7 +12,7 @@ import ( "text/template" "time" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) // AuthStatus defines the different types of auth logging that occur @@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu err := l.authTemplate.Execute(l.writer, authLogMessageData{ Client: client, - Host: util.GetRequestHost(req), + Host: requestutil.GetRequestHost(req), Protocol: req.Proto, RequestMethod: req.Method, Timestamp: FormatTimestamp(now), @@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ Client: client, - Host: util.GetRequestHost(req), + Host: requestutil.GetRequestHost(req), Protocol: req.Proto, RequestDuration: fmt.Sprintf("%0.3f", duration), RequestMethod: req.Method, diff --git a/pkg/middleware/basic_session.go b/pkg/middleware/basic_session.go index 5a7b77f9..7de1bf2b 100644 --- a/pkg/middleware/basic_session.go +++ b/pkg/middleware/basic_session.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor { // If a session was loaded by a previous handler, it will not be replaced. func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/basic_session_test.go b/pkg/middleware/basic_session_test.go index 35e4f804..14c49c43 100644 --- a/pkg/middleware/basic_session_test.go +++ b/pkg/middleware/basic_session_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "fmt" "net/http" "net/http/httptest" @@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() { // Set up the request with the authorization header and a request scope req := httptest.NewRequest("", "/", nil) req.Header.Set("Authorization", in.authorizationHeader) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() { // from the scope var gotSession *sessionsapi.SessionState handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/middleware/headers.go b/pkg/middleware/headers.go index 6786c2eb..b79b547b 100644 --- a/pkg/middleware/headers.go +++ b/pkg/middleware/headers.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" ) @@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. @@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. diff --git a/pkg/middleware/headers_test.go b/pkg/middleware/headers_test.go index 15006b1d..a9c6d73e 100644 --- a/pkg/middleware/headers_test.go +++ b/pkg/middleware/headers_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "encoding/base64" "net/http" "net/http/httptest" @@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() { // Set up the request with a request scope req := httptest.NewRequest("", "/", nil) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) req.Header = in.initialHeaders.Clone() rw := httptest.NewRecorder() @@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() { // Set up the request with a request scope req := httptest.NewRequest("", "/", nil) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() for key, values := range in.initialHeaders { diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go index 0510c72a..78ef5400 100644 --- a/pkg/middleware/jwt_session.go +++ b/pkg/middleware/jwt_session.go @@ -37,7 +37,7 @@ type jwtSessionLoader struct { // If a session was loaded by a previous handler, it will not be replaced. func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go index cd34c5ad..7786d00a 100644 --- a/pkg/middleware/jwt_session_test.go +++ b/pkg/middleware/jwt_session_test.go @@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // Set up the request with the authorization header and a request scope req := httptest.NewRequest("", "/", nil) req.Header.Set("Authorization", in.authorizationHeader) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // from the scope var gotSession *sessionsapi.SessionState handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/middleware/redirect_to_https.go b/pkg/middleware/redirect_to_https.go index 18b4b967..72f9dac4 100644 --- a/pkg/middleware/redirect_to_https.go +++ b/pkg/middleware/redirect_to_https.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/justinas/alice" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) const httpsScheme = "https" @@ -26,10 +26,11 @@ func NewRedirectToHTTPS(httpsPort string) alice.Constructor { // to the port from the httpsAddress given. func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - proto := req.Header.Get("X-Forwarded-Proto") - if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") { - // Only care about the connection to us being HTTPS if the proto is empty, - // otherwise the proto is source of truth + proto := requestutil.GetRequestProto(req) + if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == req.URL.Scheme) { + // Only care about the connection to us being HTTPS if the proto wasn't + // from a trusted `X-Forwarded-Proto` (proto == req.URL.Scheme). + // Otherwise the proto is source of truth next.ServeHTTP(rw, req) return } @@ -41,7 +42,7 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { // Set the Host in case the targetURL still does not have one // or it isn't X-Forwarded-Host aware - targetURL.Host = util.GetRequestHost(req) + targetURL.Host = requestutil.GetRequestHost(req) // Overwrite the port if the original request was to a non-standard port if targetURL.Port() != "" { diff --git a/pkg/middleware/redirect_to_https_test.go b/pkg/middleware/redirect_to_https_test.go index ca8bdb99..f8c2c6bb 100644 --- a/pkg/middleware/redirect_to_https_test.go +++ b/pkg/middleware/redirect_to_https_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http/httptest" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -21,6 +22,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString string useTLS bool headers map[string]string + reverseProxy bool expectedStatus int expectedBody string expectedLocation string @@ -35,6 +37,10 @@ var _ = Describe("RedirectToHTTPS suite", func() { if in.useTLS { req.TLS = &tls.ConnectionState{} } + scope := &middlewareapi.RequestScope{ + ReverseProxy: in.reverseProxy, + } + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -52,6 +58,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "http://example.com", useTLS: false, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -60,6 +67,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "https://example.com", useTLS: true, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 200, expectedBody: "test", }), @@ -69,15 +77,28 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "HTTPS", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), + Entry("without TLS and X-Forwarded-Proto=HTTPS but ReverseProxy not set", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTPS", + }, + reverseProxy: false, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ requestString: "https://example.com", useTLS: true, headers: map[string]string{ "X-Forwarded-Proto": "HTTPS", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), @@ -87,6 +108,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "https", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), @@ -96,6 +118,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "https", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), @@ -105,6 +128,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "HTTP", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -115,6 +139,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "HTTP", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -125,6 +150,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "http", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -135,6 +161,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "http", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -143,6 +170,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "http://example.com:8080", useTLS: false, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com:8443"), expectedLocation: "https://example.com:8443", @@ -151,6 +179,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "https://example.com:8443", useTLS: true, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 200, expectedBody: "test", }), @@ -161,6 +190,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "/", useTLS: false, expectedStatus: 308, + reverseProxy: false, expectedBody: permanentRedirectBody("https://example.com/"), expectedLocation: "https://example.com/", }), @@ -171,6 +201,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { "X-Forwarded-Proto": "HTTP", "X-Forwarded-Host": "external.example.com", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://external.example.com"), expectedLocation: "https://external.example.com", diff --git a/pkg/middleware/scope.go b/pkg/middleware/scope.go index 88719310..9218faa0 100644 --- a/pkg/middleware/scope.go +++ b/pkg/middleware/scope.go @@ -1,39 +1,20 @@ package middleware import ( - "context" "net/http" "github.com/justinas/alice" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" ) -type scopeKey string - -// requestScopeKey uses a typed string to reduce likelihood of clasing -// with other context keys -const requestScopeKey scopeKey = "request-scope" - -func NewScope() alice.Constructor { - return addScope -} - -// addScope injects a new request scope into the request context. -func addScope(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := &middlewareapi.RequestScope{} - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - requestWithScope := req.WithContext(contextWithScope) - next.ServeHTTP(rw, requestWithScope) - }) -} - -// GetRequestScope returns the current request scope from the given request -func GetRequestScope(req *http.Request) *middlewareapi.RequestScope { - scope := req.Context().Value(requestScopeKey) - if scope == nil { - return nil +func NewScope(reverseProxy bool) alice.Constructor { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := &middlewareapi.RequestScope{ + ReverseProxy: reverseProxy, + } + req = middlewareapi.AddRequestScope(req, scope) + next.ServeHTTP(rw, req) + }) } - - return scope.(*middlewareapi.RequestScope) } diff --git a/pkg/middleware/scope_test.go b/pkg/middleware/scope_test.go index e9533a8d..3432d148 100644 --- a/pkg/middleware/scope_test.go +++ b/pkg/middleware/scope_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "net/http" "net/http/httptest" @@ -21,73 +20,49 @@ var _ = Describe("Scope Suite", func() { Expect(err).ToNot(HaveOccurred()) rw = httptest.NewRecorder() - - handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - nextRequest = r - w.WriteHeader(200) - })) - handler.ServeHTTP(rw, request) }) - It("does not add a scope to the original request", func() { - Expect(request.Context().Value(requestScopeKey)).To(BeNil()) - }) - - It("cannot load a scope from the original request using GetRequestScope", func() { - Expect(GetRequestScope(request)).To(BeNil()) - }) - - It("adds a scope to the request for the next handler", func() { - Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil()) - }) - - It("can load a scope from the next handler's request using GetRequestScope", func() { - Expect(GetRequestScope(nextRequest)).ToNot(BeNil()) - }) - }) - - Context("GetRequestScope", func() { - var request *http.Request - - BeforeEach(func() { - var err error - request, err = http.NewRequest("", "http://127.0.0.1/", nil) - Expect(err).ToNot(HaveOccurred()) - }) - - Context("with a scope", func() { - var scope *middlewareapi.RequestScope - + Context("ReverseProxy is false", func() { BeforeEach(func() { - scope = &middlewareapi.RequestScope{} - contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope) - request = request.WithContext(contextWithScope) + handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) }) - It("returns the scope", func() { - s := GetRequestScope(request) - Expect(s).ToNot(BeNil()) - Expect(s).To(Equal(scope)) + It("does not add a scope to the original request", func() { + Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil()) }) - Context("if the scope is then modified", func() { - BeforeEach(func() { - Expect(scope.SaveSession).To(BeFalse()) - scope.SaveSession = true - }) + It("cannot load a scope from the original request using GetRequestScope", func() { + Expect(middlewareapi.GetRequestScope(request)).To(BeNil()) + }) - It("returns the updated session", func() { - s := GetRequestScope(request) - Expect(s).ToNot(BeNil()) - Expect(s).To(Equal(scope)) - Expect(s.SaveSession).To(BeTrue()) - }) + It("adds a scope to the request for the next handler", func() { + Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil()) + }) + + It("can load a scope from the next handler's request using GetRequestScope", func() { + scope := middlewareapi.GetRequestScope(nextRequest) + Expect(scope).ToNot(BeNil()) + Expect(scope.ReverseProxy).To(BeFalse()) }) }) - Context("without a scope", func() { - It("returns nil", func() { - Expect(GetRequestScope(request)).To(BeNil()) + Context("ReverseProxy is true", func() { + BeforeEach(func() { + handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) + }) + + It("return a scope where the ReverseProxy field is true", func() { + scope := middlewareapi.GetRequestScope(nextRequest) + Expect(scope).ToNot(BeNil()) + Expect(scope.ReverseProxy).To(BeTrue()) }) }) }) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 6d86e613..1bd0a9a4 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -8,6 +8,7 @@ import ( "time" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) @@ -59,7 +60,7 @@ type storedSessionLoader struct { // If a session was loader by a previous handler, it will not be replaced. func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 89eadc5d..4a8fd9da 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() { // Set up the request with the request headesr and a request scope req := httptest.NewRequest("", "/", nil) req.Header = in.requestHeaders - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() { // from the scope var gotSession *sessionsapi.SessionState handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/requests/util/util.go b/pkg/requests/util/util.go new file mode 100644 index 00000000..08c9c2c1 --- /dev/null +++ b/pkg/requests/util/util.go @@ -0,0 +1,48 @@ +package util + +import ( + "net/http" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" +) + +// GetRequestProto returns the request scheme or X-Forwarded-Proto if present +// and the request is proxied. +func GetRequestProto(req *http.Request) string { + proto := req.Header.Get("X-Forwarded-Proto") + if !IsProxied(req) || proto == "" { + proto = req.URL.Scheme + } + return proto +} + +// GetRequestHost returns the request host header or X-Forwarded-Host if +// present and the request is proxied. +func GetRequestHost(req *http.Request) string { + host := req.Header.Get("X-Forwarded-Host") + if !IsProxied(req) || host == "" { + host = req.Host + } + return host +} + +// GetRequestURI return the request URI or X-Forwarded-Uri if present and the +// request is proxied. +func GetRequestURI(req *http.Request) string { + uri := req.Header.Get("X-Forwarded-Uri") + if !IsProxied(req) || uri == "" { + // Use RequestURI to preserve ?query + uri = req.URL.RequestURI() + } + return uri +} + +// IsProxied determines if a request was from a proxy based on the RequestScope +// ReverseProxy tracker. +func IsProxied(req *http.Request) bool { + scope := middlewareapi.GetRequestScope(req) + if scope == nil { + return false + } + return scope.ReverseProxy +} diff --git a/pkg/requests/util/util_suite_test.go b/pkg/requests/util/util_suite_test.go new file mode 100644 index 00000000..a03f943f --- /dev/null +++ b/pkg/requests/util/util_suite_test.go @@ -0,0 +1,19 @@ +package util_test + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// TestRequestUtilSuite and related tests are in a *_test package +// to prevent circular imports with the `logger` package which uses +// this functionality +func TestRequestUtilSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Request Utils") +} diff --git a/pkg/requests/util/util_test.go b/pkg/requests/util/util_test.go new file mode 100644 index 00000000..595f93f6 --- /dev/null +++ b/pkg/requests/util/util_test.go @@ -0,0 +1,131 @@ +package util_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Util Suite", func() { + const ( + proto = "http" + host = "www.oauth2proxy.test" + uri = "/test/endpoint" + ) + var req *http.Request + + BeforeEach(func() { + req = httptest.NewRequest( + http.MethodGet, + fmt.Sprintf("%s://%s%s", proto, host, uri), + nil, + ) + }) + + Context("GetRequestHost", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the host", func() { + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + + It("ignores X-Forwarded-Host and returns the host", func() { + req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text") + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the host if X-Forwarded-Host is not present", func() { + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + + It("returns the X-Forwarded-Host when present", func() { + req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text") + Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text")) + }) + }) + }) + + Context("GetRequestProto", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the scheme", func() { + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + + It("ignores X-Forwarded-Proto and returns the scheme", func() { + req.Header.Add("X-Forwarded-Proto", "https") + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the scheme if X-Forwarded-Proto is not present", func() { + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + + It("returns the X-Forwarded-Proto when present", func() { + req.Header.Add("X-Forwarded-Proto", "https") + Expect(util.GetRequestProto(req)).To(Equal("https")) + }) + }) + }) + + Context("GetRequestURI", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the URI", func() { + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + + It("ignores X-Forwarded-Uri and returns the URI", func() { + req.Header.Add("X-Forwarded-Uri", "/some/other/path") + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the URI if X-Forwarded-Uri is not present", func() { + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + + It("returns the X-Forwarded-Uri when present", func() { + req.Header.Add("X-Forwarded-Uri", "/some/other/path") + Expect(util.GetRequestURI(req)).To(Equal("/some/other/path")) + }) + }) + }) +}) diff --git a/pkg/util/util.go b/pkg/util/util.go index 4eeabbf7..4519fdb8 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -4,7 +4,6 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "net/http" ) func GetCertPool(paths []string) (*x509.CertPool, error) { @@ -24,31 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { } return pool, nil } - -// GetRequestProto return the request host header or X-Forwarded-Proto if present -func GetRequestProto(req *http.Request) string { - proto := req.Header.Get("X-Forwarded-Proto") - if proto == "" { - proto = req.URL.Scheme - } - return proto -} - -// GetRequestHost return the request host header or X-Forwarded-Host if present -func GetRequestHost(req *http.Request) string { - host := req.Header.Get("X-Forwarded-Host") - if host == "" { - host = req.Host - } - return host -} - -// GetRequestURI return the request host header or X-Forwarded-Uri if present -func GetRequestURI(req *http.Request) string { - uri := req.Header.Get("X-Forwarded-Uri") - if uri == "" { - // Use RequestURI to preserve ?query - uri = req.URL.RequestURI() - } - return uri -} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index d032025e..347f41bb 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -4,11 +4,9 @@ import ( "crypto/x509/pkix" "encoding/asn1" "io/ioutil" - "net/http/httptest" "os" "testing" - . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) @@ -97,42 +95,3 @@ func TestGetCertPool(t *testing.T) { expectedSubjects := []string{testCA1Subj, testCA2Subj} assert.Equal(t, expectedSubjects, got) } - -func TestGetRequestHost(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com", nil) - host := GetRequestHost(req) - g.Expect(host).To(Equal("example.com")) - - proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil) - proxyReq.Header.Add("X-Forwarded-Host", "external.example.com") - extHost := GetRequestHost(proxyReq) - g.Expect(extHost).To(Equal("external.example.com")) -} - -func TestGetRequestProto(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com", nil) - proto := GetRequestProto(req) - g.Expect(proto).To(Equal("https")) - - proxyReq := httptest.NewRequest("GET", "https://internal.example.com", nil) - proxyReq.Header.Add("X-Forwarded-Proto", "http") - extProto := GetRequestProto(proxyReq) - g.Expect(extProto).To(Equal("http")) -} - -func TestGetRequestURI(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com/ping", nil) - uri := GetRequestURI(req) - g.Expect(uri).To(Equal("/ping")) - - proxyReq := httptest.NewRequest("GET", "http://internal.example.com/bong", nil) - proxyReq.Header.Add("X-Forwarded-Uri", "/ping") - extURI := GetRequestURI(proxyReq) - g.Expect(extURI).To(Equal("/ping")) -}