Merge pull request #729 from grnhse/x-forwarded-host-redirect
Use X-Forwarded-Host in Redirects
This commit is contained in:
		
						commit
						73f0094486
					
				|  | @ -8,6 +8,8 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v6.1.0 | ## Changes since v6.1.0 | ||||||
| 
 | 
 | ||||||
|  | - [#729](https://github.com/oauth2-proxy/oauth2-proxy/pull/729) Use X-Forwarded-Host consistently when set (@NickMeves) | ||||||
|  | 
 | ||||||
| # v6.1.0 | # v6.1.0 | ||||||
| 
 | 
 | ||||||
| ## Release Highlights | ## Release Highlights | ||||||
|  |  | ||||||
|  | @ -28,6 +28,7 @@ import ( | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/util" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/providers" | 	"github.com/oauth2-proxy/oauth2-proxy/providers" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -332,7 +333,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex | ||||||
| 	cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) | 	cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) | ||||||
| 
 | 
 | ||||||
| 	if cookieDomain != "" { | 	if cookieDomain != "" { | ||||||
| 		domain := cookies.GetRequestHost(req) | 		domain := util.GetRequestHost(req) | ||||||
| 		if h, _, err := net.SplitHostPort(domain); err == nil { | 		if h, _, err := net.SplitHostPort(domain); err == nil { | ||||||
| 			domain = h | 			domain = h | ||||||
| 		} | 		} | ||||||
|  | @ -747,7 +748,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	redirectURI := p.GetRedirectURI(req.Host) | 	redirectURI := p.GetRedirectURI(util.GetRequestHost(req)) | ||||||
| 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) | 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -770,7 +771,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code")) | 	session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) | 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | ||||||
|  |  | ||||||
|  | @ -9,13 +9,14 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // MakeCookie constructs a cookie from the given parameters,
 | // MakeCookie constructs a cookie from the given parameters,
 | ||||||
| // discovering the domain from the request if not specified.
 | // discovering the domain from the request if not specified.
 | ||||||
| func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { | func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { | ||||||
| 	if domain != "" { | 	if domain != "" { | ||||||
| 		host := req.Host | 		host := util.GetRequestHost(req) | ||||||
| 		if h, _, err := net.SplitHostPort(host); err == nil { | 		if h, _, err := net.SplitHostPort(host); err == nil { | ||||||
| 			host = h | 			host = h | ||||||
| 		} | 		} | ||||||
|  | @ -47,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO | ||||||
| 	// If nothing matches, create the cookie with the shortest domain
 | 	// If nothing matches, create the cookie with the shortest domain
 | ||||||
| 	defaultDomain := "" | 	defaultDomain := "" | ||||||
| 	if len(cookieOpts.Domains) > 0 { | 	if len(cookieOpts.Domains) > 0 { | ||||||
| 		logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) | 		logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) | ||||||
| 		defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] | 		defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] | ||||||
| 	} | 	} | ||||||
| 	return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) | 	return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) | ||||||
|  | @ -56,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO | ||||||
| // GetCookieDomain returns the correct cookie domain given a list of domains
 | // GetCookieDomain returns the correct cookie domain given a list of domains
 | ||||||
| // by checking the X-Fowarded-Host and host header of an an http request
 | // by checking the X-Fowarded-Host and host header of an an http request
 | ||||||
| func GetCookieDomain(req *http.Request, cookieDomains []string) string { | func GetCookieDomain(req *http.Request, cookieDomains []string) string { | ||||||
| 	host := GetRequestHost(req) | 	host := util.GetRequestHost(req) | ||||||
| 	for _, domain := range cookieDomains { | 	for _, domain := range cookieDomains { | ||||||
| 		if strings.HasSuffix(host, domain) { | 		if strings.HasSuffix(host, domain) { | ||||||
| 			return domain | 			return domain | ||||||
|  | @ -65,15 +66,6 @@ func GetCookieDomain(req *http.Request, cookieDomains []string) string { | ||||||
| 	return "" | 	return "" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetRequestHost return the request host header or X-Forwarded-Host if present
 |  | ||||||
| func GetRequestHost(req *http.Request) string { |  | ||||||
| 	host := req.Header.Get("X-Forwarded-Host") |  | ||||||
| 	if host == "" { |  | ||||||
| 		host = req.Host |  | ||||||
| 	} |  | ||||||
| 	return host |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Parse a valid http.SameSite value from a user supplied string for use of making cookies.
 | // Parse a valid http.SameSite value from a user supplied string for use of making cookies.
 | ||||||
| func ParseSameSite(v string) http.SameSite { | func ParseSameSite(v string) http.SameSite { | ||||||
| 	switch v { | 	switch v { | ||||||
|  |  | ||||||
|  | @ -11,6 +11,8 @@ import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"text/template" | 	"text/template" | ||||||
| 	"time" | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // AuthStatus defines the different types of auth logging that occur
 | // AuthStatus defines the different types of auth logging that occur
 | ||||||
|  | @ -195,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu | ||||||
| 
 | 
 | ||||||
| 	err := l.authTemplate.Execute(l.writer, authLogMessageData{ | 	err := l.authTemplate.Execute(l.writer, authLogMessageData{ | ||||||
| 		Client:        client, | 		Client:        client, | ||||||
| 		Host:          req.Host, | 		Host:          util.GetRequestHost(req), | ||||||
| 		Protocol:      req.Proto, | 		Protocol:      req.Proto, | ||||||
| 		RequestMethod: req.Method, | 		RequestMethod: req.Method, | ||||||
| 		Timestamp:     FormatTimestamp(now), | 		Timestamp:     FormatTimestamp(now), | ||||||
|  | @ -249,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. | ||||||
| 
 | 
 | ||||||
| 	err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ | 	err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ | ||||||
| 		Client:          client, | 		Client:          client, | ||||||
| 		Host:            req.Host, | 		Host:            util.GetRequestHost(req), | ||||||
| 		Protocol:        req.Proto, | 		Protocol:        req.Proto, | ||||||
| 		RequestDuration: fmt.Sprintf("%0.3f", duration), | 		RequestDuration: fmt.Sprintf("%0.3f", duration), | ||||||
| 		RequestMethod:   req.Method, | 		RequestMethod:   req.Method, | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/justinas/alice" | 	"github.com/justinas/alice" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const httpsScheme = "https" | const httpsScheme = "https" | ||||||
|  | @ -38,10 +39,9 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { | ||||||
| 		// Set the scheme to HTTPS
 | 		// Set the scheme to HTTPS
 | ||||||
| 		targetURL.Scheme = httpsScheme | 		targetURL.Scheme = httpsScheme | ||||||
| 
 | 
 | ||||||
| 		// Set the req.Host when the targetURL still does not have one
 | 		// Set the Host in case the targetURL still does not have one
 | ||||||
| 		if targetURL.Host == "" { | 		// or it isn't X-Forwarded-Host aware
 | ||||||
| 			targetURL.Host = req.Host | 		targetURL.Host = util.GetRequestHost(req) | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		// Overwrite the port if the original request was to a non-standard port
 | 		// Overwrite the port if the original request was to a non-standard port
 | ||||||
| 		if targetURL.Port() != "" { | 		if targetURL.Port() != "" { | ||||||
|  |  | ||||||
|  | @ -164,5 +164,16 @@ var _ = Describe("RedirectToHTTPS suite", func() { | ||||||
| 			expectedBody:     permanentRedirectBody("https://example.com/"), | 			expectedBody:     permanentRedirectBody("https://example.com/"), | ||||||
| 			expectedLocation: "https://example.com/", | 			expectedLocation: "https://example.com/", | ||||||
| 		}), | 		}), | ||||||
|  | 		Entry("without TLS with an X-Forwarded-Host header", &requestTableInput{ | ||||||
|  | 			requestString: "http://internal.example.com", | ||||||
|  | 			useTLS:        false, | ||||||
|  | 			headers: map[string]string{ | ||||||
|  | 				"X-Forwarded-Proto": "HTTP", | ||||||
|  | 				"X-Forwarded-Host":  "external.example.com", | ||||||
|  | 			}, | ||||||
|  | 			expectedStatus:   308, | ||||||
|  | 			expectedBody:     permanentRedirectBody("https://external.example.com"), | ||||||
|  | 			expectedLocation: "https://external.example.com", | ||||||
|  | 		}), | ||||||
| 	) | 	) | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ import ( | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func GetCertPool(paths []string) (*x509.CertPool, error) { | func GetCertPool(paths []string) (*x509.CertPool, error) { | ||||||
|  | @ -23,3 +24,12 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { | ||||||
| 	} | 	} | ||||||
| 	return pool, nil | 	return pool, nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // GetRequestHost return the request host header or X-Forwarded-Host if present
 | ||||||
|  | func GetRequestHost(req *http.Request) string { | ||||||
|  | 	host := req.Header.Get("X-Forwarded-Host") | ||||||
|  | 	if host == "" { | ||||||
|  | 		host = req.Host | ||||||
|  | 	} | ||||||
|  | 	return host | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -4,9 +4,11 @@ import ( | ||||||
| 	"crypto/x509/pkix" | 	"crypto/x509/pkix" | ||||||
| 	"encoding/asn1" | 	"encoding/asn1" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
|  | 	"net/http/httptest" | ||||||
| 	"os" | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -70,7 +72,13 @@ func TestGetCertPool_NoRoots(t *testing.T) { | ||||||
| func TestGetCertPool(t *testing.T) { | func TestGetCertPool(t *testing.T) { | ||||||
| 	tempDir, err := ioutil.TempDir("", "certtest") | 	tempDir, err := ioutil.TempDir("", "certtest") | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 	defer os.RemoveAll(tempDir) | 	defer func(path string) { | ||||||
|  | 		rerr := os.RemoveAll(path) | ||||||
|  | 		if rerr != nil { | ||||||
|  | 			panic(rerr) | ||||||
|  | 		} | ||||||
|  | 	}(tempDir) | ||||||
|  | 
 | ||||||
| 	certFile1 := makeTestCertFile(t, testCA1, tempDir) | 	certFile1 := makeTestCertFile(t, testCA1, tempDir) | ||||||
| 	certFile2 := makeTestCertFile(t, testCA2, tempDir) | 	certFile2 := makeTestCertFile(t, testCA2, tempDir) | ||||||
| 
 | 
 | ||||||
|  | @ -89,3 +97,16 @@ func TestGetCertPool(t *testing.T) { | ||||||
| 	expectedSubjects := []string{testCA1Subj, testCA2Subj} | 	expectedSubjects := []string{testCA1Subj, testCA2Subj} | ||||||
| 	assert.Equal(t, expectedSubjects, got) | 	assert.Equal(t, expectedSubjects, got) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestGetRequestHost(t *testing.T) { | ||||||
|  | 	g := NewWithT(t) | ||||||
|  | 
 | ||||||
|  | 	req := httptest.NewRequest("GET", "https://example.com", nil) | ||||||
|  | 	host := GetRequestHost(req) | ||||||
|  | 	g.Expect(host).To(Equal("example.com")) | ||||||
|  | 
 | ||||||
|  | 	proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil) | ||||||
|  | 	proxyReq.Header.Add("X-Forwarded-Host", "external.example.com") | ||||||
|  | 	extHost := GetRequestHost(proxyReq) | ||||||
|  | 	g.Expect(extHost).To(Equal("external.example.com")) | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue