Use X-Forwarded-Host consistently

This commit is contained in:
Nick Meves 2020-08-21 19:50:32 -07:00
parent bd5fab478d
commit 29b24793e3
No known key found for this signature in database
GPG Key ID: 93BA8A3CEDCDD1CF
8 changed files with 61 additions and 22 deletions

View File

@ -8,6 +8,8 @@
## Changes since v6.1.0 ## Changes since v6.1.0
- [#729](https://github.com/oauth2-proxy/oauth2-proxy/pull/729) Use X-Forwarded-Host consistently when set (@NickMeves)
# v6.1.0 # v6.1.0
## Release Highlights ## Release Highlights

View File

@ -28,6 +28,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
"github.com/oauth2-proxy/oauth2-proxy/providers" "github.com/oauth2-proxy/oauth2-proxy/providers"
) )
@ -332,7 +333,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
if cookieDomain != "" { if cookieDomain != "" {
domain := cookies.GetRequestHost(req) domain := util.GetRequestHost(req)
if h, _, err := net.SplitHostPort(domain); err == nil { if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h domain = h
} }
@ -747,7 +748,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
return return
} }
redirectURI := p.GetRedirectURI(req.Host) redirectURI := p.GetRedirectURI(util.GetRequestHost(req))
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
} }
@ -770,7 +771,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return return
} }
session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code")) session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code"))
if err != nil { if err != nil {
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")

View File

@ -9,13 +9,14 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
) )
// MakeCookie constructs a cookie from the given parameters, // MakeCookie constructs a cookie from the given parameters,
// discovering the domain from the request if not specified. // discovering the domain from the request if not specified.
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
if domain != "" { if domain != "" {
host := req.Host host := util.GetRequestHost(req)
if h, _, err := net.SplitHostPort(host); err == nil { if h, _, err := net.SplitHostPort(host); err == nil {
host = h host = h
} }
@ -47,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
// If nothing matches, create the cookie with the shortest domain // If nothing matches, create the cookie with the shortest domain
defaultDomain := "" defaultDomain := ""
if len(cookieOpts.Domains) > 0 { if len(cookieOpts.Domains) > 0 {
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
} }
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
@ -56,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
// GetCookieDomain returns the correct cookie domain given a list of domains // GetCookieDomain returns the correct cookie domain given a list of domains
// by checking the X-Fowarded-Host and host header of an an http request // by checking the X-Fowarded-Host and host header of an an http request
func GetCookieDomain(req *http.Request, cookieDomains []string) string { func GetCookieDomain(req *http.Request, cookieDomains []string) string {
host := GetRequestHost(req) host := util.GetRequestHost(req)
for _, domain := range cookieDomains { for _, domain := range cookieDomains {
if strings.HasSuffix(host, domain) { if strings.HasSuffix(host, domain) {
return domain return domain
@ -65,15 +66,6 @@ func GetCookieDomain(req *http.Request, cookieDomains []string) string {
return "" return ""
} }
// GetRequestHost return the request host header or X-Forwarded-Host if present
func GetRequestHost(req *http.Request) string {
host := req.Header.Get("X-Forwarded-Host")
if host == "" {
host = req.Host
}
return host
}
// Parse a valid http.SameSite value from a user supplied string for use of making cookies. // Parse a valid http.SameSite value from a user supplied string for use of making cookies.
func ParseSameSite(v string) http.SameSite { func ParseSameSite(v string) http.SameSite {
switch v { switch v {

View File

@ -11,6 +11,8 @@ import (
"sync" "sync"
"text/template" "text/template"
"time" "time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
) )
// AuthStatus defines the different types of auth logging that occur // AuthStatus defines the different types of auth logging that occur
@ -195,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
err := l.authTemplate.Execute(l.writer, authLogMessageData{ err := l.authTemplate.Execute(l.writer, authLogMessageData{
Client: client, Client: client,
Host: req.Host, Host: util.GetRequestHost(req),
Protocol: req.Proto, Protocol: req.Proto,
RequestMethod: req.Method, RequestMethod: req.Method,
Timestamp: FormatTimestamp(now), Timestamp: FormatTimestamp(now),
@ -249,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
Client: client, Client: client,
Host: req.Host, Host: util.GetRequestHost(req),
Protocol: req.Proto, Protocol: req.Proto,
RequestDuration: fmt.Sprintf("%0.3f", duration), RequestDuration: fmt.Sprintf("%0.3f", duration),
RequestMethod: req.Method, RequestMethod: req.Method,

View File

@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/justinas/alice" "github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
) )
const httpsScheme = "https" const httpsScheme = "https"
@ -38,10 +39,9 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
// Set the scheme to HTTPS // Set the scheme to HTTPS
targetURL.Scheme = httpsScheme targetURL.Scheme = httpsScheme
// Set the req.Host when the targetURL still does not have one // Set the Host in case the targetURL still does not have one
if targetURL.Host == "" { // or it isn't X-Forwarded-Host aware
targetURL.Host = req.Host targetURL.Host = util.GetRequestHost(req)
}
// Overwrite the port if the original request was to a non-standard port // Overwrite the port if the original request was to a non-standard port
if targetURL.Port() != "" { if targetURL.Port() != "" {

View File

@ -164,5 +164,16 @@ var _ = Describe("RedirectToHTTPS suite", func() {
expectedBody: permanentRedirectBody("https://example.com/"), expectedBody: permanentRedirectBody("https://example.com/"),
expectedLocation: "https://example.com/", expectedLocation: "https://example.com/",
}), }),
Entry("without TLS with an X-Forwarded-Host header", &requestTableInput{
requestString: "http://internal.example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "HTTP",
"X-Forwarded-Host": "external.example.com",
},
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://external.example.com"),
expectedLocation: "https://external.example.com",
}),
) )
}) })

View File

@ -4,6 +4,7 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http"
) )
func GetCertPool(paths []string) (*x509.CertPool, error) { func GetCertPool(paths []string) (*x509.CertPool, error) {
@ -23,3 +24,12 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
} }
return pool, nil return pool, nil
} }
// GetRequestHost return the request host header or X-Forwarded-Host if present
func GetRequestHost(req *http.Request) string {
host := req.Header.Get("X-Forwarded-Host")
if host == "" {
host = req.Host
}
return host
}

View File

@ -4,9 +4,11 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"io/ioutil" "io/ioutil"
"net/http/httptest"
"os" "os"
"testing" "testing"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -70,7 +72,13 @@ func TestGetCertPool_NoRoots(t *testing.T) {
func TestGetCertPool(t *testing.T) { func TestGetCertPool(t *testing.T) {
tempDir, err := ioutil.TempDir("", "certtest") tempDir, err := ioutil.TempDir("", "certtest")
assert.NoError(t, err) assert.NoError(t, err)
defer os.RemoveAll(tempDir) defer func(path string) {
rerr := os.RemoveAll(path)
if rerr != nil {
panic(rerr)
}
}(tempDir)
certFile1 := makeTestCertFile(t, testCA1, tempDir) certFile1 := makeTestCertFile(t, testCA1, tempDir)
certFile2 := makeTestCertFile(t, testCA2, tempDir) certFile2 := makeTestCertFile(t, testCA2, tempDir)
@ -89,3 +97,16 @@ func TestGetCertPool(t *testing.T) {
expectedSubjects := []string{testCA1Subj, testCA2Subj} expectedSubjects := []string{testCA1Subj, testCA2Subj}
assert.Equal(t, expectedSubjects, got) assert.Equal(t, expectedSubjects, got)
} }
func TestGetRequestHost(t *testing.T) {
g := NewWithT(t)
req := httptest.NewRequest("GET", "https://example.com", nil)
host := GetRequestHost(req)
g.Expect(host).To(Equal("example.com"))
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
extHost := GetRequestHost(proxyReq)
g.Expect(extHost).To(Equal("external.example.com"))
}