diff --git a/CHANGELOG.md b/CHANGELOG.md index 00dc6204..dbec2b12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ ## Changes since v6.1.0 +- [#729](https://github.com/oauth2-proxy/oauth2-proxy/pull/729) Use X-Forwarded-Host consistently when set (@NickMeves) + # v6.1.0 ## Release Highlights diff --git a/oauthproxy.go b/oauthproxy.go index 5bfb7631..f4d3c496 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -28,6 +28,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" + "github.com/oauth2-proxy/oauth2-proxy/pkg/util" "github.com/oauth2-proxy/oauth2-proxy/providers" ) @@ -332,7 +333,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) if cookieDomain != "" { - domain := cookies.GetRequestHost(req) + domain := util.GetRequestHost(req) if h, _, err := net.SplitHostPort(domain); err == nil { domain = h } @@ -747,7 +748,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - redirectURI := p.GetRedirectURI(req.Host) + redirectURI := p.GetRedirectURI(util.GetRequestHost(req)) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) } @@ -770,7 +771,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code")) + session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code")) if err != nil { logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go index 6967db9f..0d499639 100644 --- a/pkg/cookies/cookies.go +++ b/pkg/cookies/cookies.go @@ -9,13 +9,14 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/util" ) // MakeCookie constructs a cookie from the given parameters, // discovering the domain from the request if not specified. func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { if domain != "" { - host := req.Host + host := util.GetRequestHost(req) if h, _, err := net.SplitHostPort(host); err == nil { host = h } @@ -47,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // If nothing matches, create the cookie with the shortest domain defaultDomain := "" if len(cookieOpts.Domains) > 0 { - logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) + logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] } return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) @@ -56,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // GetCookieDomain returns the correct cookie domain given a list of domains // by checking the X-Fowarded-Host and host header of an an http request func GetCookieDomain(req *http.Request, cookieDomains []string) string { - host := GetRequestHost(req) + host := util.GetRequestHost(req) for _, domain := range cookieDomains { if strings.HasSuffix(host, domain) { return domain @@ -65,15 +66,6 @@ func GetCookieDomain(req *http.Request, cookieDomains []string) string { return "" } -// GetRequestHost return the request host header or X-Forwarded-Host if present -func GetRequestHost(req *http.Request) string { - host := req.Header.Get("X-Forwarded-Host") - if host == "" { - host = req.Host - } - return host -} - // Parse a valid http.SameSite value from a user supplied string for use of making cookies. func ParseSameSite(v string) http.SameSite { switch v { diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index d5aab57c..901d0a0d 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -11,6 +11,8 @@ import ( "sync" "text/template" "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/util" ) // AuthStatus defines the different types of auth logging that occur @@ -195,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu err := l.authTemplate.Execute(l.writer, authLogMessageData{ Client: client, - Host: req.Host, + Host: util.GetRequestHost(req), Protocol: req.Proto, RequestMethod: req.Method, Timestamp: FormatTimestamp(now), @@ -249,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ Client: client, - Host: req.Host, + Host: util.GetRequestHost(req), Protocol: req.Proto, RequestDuration: fmt.Sprintf("%0.3f", duration), RequestMethod: req.Method, diff --git a/pkg/middleware/redirect_to_https.go b/pkg/middleware/redirect_to_https.go index 27384561..691a0565 100644 --- a/pkg/middleware/redirect_to_https.go +++ b/pkg/middleware/redirect_to_https.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/justinas/alice" + "github.com/oauth2-proxy/oauth2-proxy/pkg/util" ) const httpsScheme = "https" @@ -38,10 +39,9 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { // Set the scheme to HTTPS targetURL.Scheme = httpsScheme - // Set the req.Host when the targetURL still does not have one - if targetURL.Host == "" { - targetURL.Host = req.Host - } + // Set the Host in case the targetURL still does not have one + // or it isn't X-Forwarded-Host aware + targetURL.Host = util.GetRequestHost(req) // Overwrite the port if the original request was to a non-standard port if targetURL.Port() != "" { diff --git a/pkg/middleware/redirect_to_https_test.go b/pkg/middleware/redirect_to_https_test.go index 238c3b81..ca8bdb99 100644 --- a/pkg/middleware/redirect_to_https_test.go +++ b/pkg/middleware/redirect_to_https_test.go @@ -164,5 +164,16 @@ var _ = Describe("RedirectToHTTPS suite", func() { expectedBody: permanentRedirectBody("https://example.com/"), expectedLocation: "https://example.com/", }), + Entry("without TLS with an X-Forwarded-Host header", &requestTableInput{ + requestString: "http://internal.example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTP", + "X-Forwarded-Host": "external.example.com", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://external.example.com"), + expectedLocation: "https://external.example.com", + }), ) }) diff --git a/pkg/util/util.go b/pkg/util/util.go index 4519fdb8..b39c1032 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -4,6 +4,7 @@ import ( "crypto/x509" "fmt" "io/ioutil" + "net/http" ) func GetCertPool(paths []string) (*x509.CertPool, error) { @@ -23,3 +24,12 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { } return pool, nil } + +// GetRequestHost return the request host header or X-Forwarded-Host if present +func GetRequestHost(req *http.Request) string { + host := req.Header.Get("X-Forwarded-Host") + if host == "" { + host = req.Host + } + return host +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 63816209..c1e3d688 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -4,9 +4,11 @@ import ( "crypto/x509/pkix" "encoding/asn1" "io/ioutil" + "net/http/httptest" "os" "testing" + . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) @@ -70,7 +72,13 @@ func TestGetCertPool_NoRoots(t *testing.T) { func TestGetCertPool(t *testing.T) { tempDir, err := ioutil.TempDir("", "certtest") assert.NoError(t, err) - defer os.RemoveAll(tempDir) + defer func(path string) { + rerr := os.RemoveAll(path) + if rerr != nil { + panic(rerr) + } + }(tempDir) + certFile1 := makeTestCertFile(t, testCA1, tempDir) certFile2 := makeTestCertFile(t, testCA2, tempDir) @@ -89,3 +97,16 @@ func TestGetCertPool(t *testing.T) { expectedSubjects := []string{testCA1Subj, testCA2Subj} assert.Equal(t, expectedSubjects, got) } + +func TestGetRequestHost(t *testing.T) { + g := NewWithT(t) + + req := httptest.NewRequest("GET", "https://example.com", nil) + host := GetRequestHost(req) + g.Expect(host).To(Equal("example.com")) + + proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil) + proxyReq.Header.Add("X-Forwarded-Host", "external.example.com") + extHost := GetRequestHost(proxyReq) + g.Expect(extHost).To(Equal("external.example.com")) +}