Use X-Forwarded-Host consistently
This commit is contained in:
parent
bd5fab478d
commit
29b24793e3
|
|
@ -8,6 +8,8 @@
|
||||||
|
|
||||||
## Changes since v6.1.0
|
## Changes since v6.1.0
|
||||||
|
|
||||||
|
- [#729](https://github.com/oauth2-proxy/oauth2-proxy/pull/729) Use X-Forwarded-Host consistently when set (@NickMeves)
|
||||||
|
|
||||||
# v6.1.0
|
# v6.1.0
|
||||||
|
|
||||||
## Release Highlights
|
## Release Highlights
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/providers"
|
"github.com/oauth2-proxy/oauth2-proxy/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -332,7 +333,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
|
||||||
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
|
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
|
||||||
|
|
||||||
if cookieDomain != "" {
|
if cookieDomain != "" {
|
||||||
domain := cookies.GetRequestHost(req)
|
domain := util.GetRequestHost(req)
|
||||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||||
domain = h
|
domain = h
|
||||||
}
|
}
|
||||||
|
|
@ -747,7 +748,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redirectURI := p.GetRedirectURI(req.Host)
|
redirectURI := p.GetRedirectURI(util.GetRequestHost(req))
|
||||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
|
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -770,7 +771,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code"))
|
session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
|
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
|
||||||
|
|
|
||||||
|
|
@ -9,13 +9,14 @@ import (
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MakeCookie constructs a cookie from the given parameters,
|
// MakeCookie constructs a cookie from the given parameters,
|
||||||
// discovering the domain from the request if not specified.
|
// discovering the domain from the request if not specified.
|
||||||
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
|
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
|
||||||
if domain != "" {
|
if domain != "" {
|
||||||
host := req.Host
|
host := util.GetRequestHost(req)
|
||||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
host = h
|
host = h
|
||||||
}
|
}
|
||||||
|
|
@ -47,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||||
// If nothing matches, create the cookie with the shortest domain
|
// If nothing matches, create the cookie with the shortest domain
|
||||||
defaultDomain := ""
|
defaultDomain := ""
|
||||||
if len(cookieOpts.Domains) > 0 {
|
if len(cookieOpts.Domains) > 0 {
|
||||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||||
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
|
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
|
||||||
}
|
}
|
||||||
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
|
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
|
||||||
|
|
@ -56,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||||
// GetCookieDomain returns the correct cookie domain given a list of domains
|
// GetCookieDomain returns the correct cookie domain given a list of domains
|
||||||
// by checking the X-Fowarded-Host and host header of an an http request
|
// by checking the X-Fowarded-Host and host header of an an http request
|
||||||
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
||||||
host := GetRequestHost(req)
|
host := util.GetRequestHost(req)
|
||||||
for _, domain := range cookieDomains {
|
for _, domain := range cookieDomains {
|
||||||
if strings.HasSuffix(host, domain) {
|
if strings.HasSuffix(host, domain) {
|
||||||
return domain
|
return domain
|
||||||
|
|
@ -65,15 +66,6 @@ func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRequestHost return the request host header or X-Forwarded-Host if present
|
|
||||||
func GetRequestHost(req *http.Request) string {
|
|
||||||
host := req.Header.Get("X-Forwarded-Host")
|
|
||||||
if host == "" {
|
|
||||||
host = req.Host
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse a valid http.SameSite value from a user supplied string for use of making cookies.
|
// Parse a valid http.SameSite value from a user supplied string for use of making cookies.
|
||||||
func ParseSameSite(v string) http.SameSite {
|
func ParseSameSite(v string) http.SameSite {
|
||||||
switch v {
|
switch v {
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthStatus defines the different types of auth logging that occur
|
// AuthStatus defines the different types of auth logging that occur
|
||||||
|
|
@ -195,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
|
||||||
|
|
||||||
err := l.authTemplate.Execute(l.writer, authLogMessageData{
|
err := l.authTemplate.Execute(l.writer, authLogMessageData{
|
||||||
Client: client,
|
Client: client,
|
||||||
Host: req.Host,
|
Host: util.GetRequestHost(req),
|
||||||
Protocol: req.Proto,
|
Protocol: req.Proto,
|
||||||
RequestMethod: req.Method,
|
RequestMethod: req.Method,
|
||||||
Timestamp: FormatTimestamp(now),
|
Timestamp: FormatTimestamp(now),
|
||||||
|
|
@ -249,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
|
||||||
|
|
||||||
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
|
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
|
||||||
Client: client,
|
Client: client,
|
||||||
Host: req.Host,
|
Host: util.GetRequestHost(req),
|
||||||
Protocol: req.Proto,
|
Protocol: req.Proto,
|
||||||
RequestDuration: fmt.Sprintf("%0.3f", duration),
|
RequestDuration: fmt.Sprintf("%0.3f", duration),
|
||||||
RequestMethod: req.Method,
|
RequestMethod: req.Method,
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
"github.com/justinas/alice"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const httpsScheme = "https"
|
const httpsScheme = "https"
|
||||||
|
|
@ -38,10 +39,9 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||||
// Set the scheme to HTTPS
|
// Set the scheme to HTTPS
|
||||||
targetURL.Scheme = httpsScheme
|
targetURL.Scheme = httpsScheme
|
||||||
|
|
||||||
// Set the req.Host when the targetURL still does not have one
|
// Set the Host in case the targetURL still does not have one
|
||||||
if targetURL.Host == "" {
|
// or it isn't X-Forwarded-Host aware
|
||||||
targetURL.Host = req.Host
|
targetURL.Host = util.GetRequestHost(req)
|
||||||
}
|
|
||||||
|
|
||||||
// Overwrite the port if the original request was to a non-standard port
|
// Overwrite the port if the original request was to a non-standard port
|
||||||
if targetURL.Port() != "" {
|
if targetURL.Port() != "" {
|
||||||
|
|
|
||||||
|
|
@ -164,5 +164,16 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
expectedBody: permanentRedirectBody("https://example.com/"),
|
expectedBody: permanentRedirectBody("https://example.com/"),
|
||||||
expectedLocation: "https://example.com/",
|
expectedLocation: "https://example.com/",
|
||||||
}),
|
}),
|
||||||
|
Entry("without TLS with an X-Forwarded-Host header", &requestTableInput{
|
||||||
|
requestString: "http://internal.example.com",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "HTTP",
|
||||||
|
"X-Forwarded-Host": "external.example.com",
|
||||||
|
},
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://external.example.com"),
|
||||||
|
expectedLocation: "https://external.example.com",
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetCertPool(paths []string) (*x509.CertPool, error) {
|
func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||||
|
|
@ -23,3 +24,12 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||||
}
|
}
|
||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRequestHost return the request host header or X-Forwarded-Host if present
|
||||||
|
func GetRequestHost(req *http.Request) string {
|
||||||
|
host := req.Header.Get("X-Forwarded-Host")
|
||||||
|
if host == "" {
|
||||||
|
host = req.Host
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,11 @@ import (
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -70,7 +72,13 @@ func TestGetCertPool_NoRoots(t *testing.T) {
|
||||||
func TestGetCertPool(t *testing.T) {
|
func TestGetCertPool(t *testing.T) {
|
||||||
tempDir, err := ioutil.TempDir("", "certtest")
|
tempDir, err := ioutil.TempDir("", "certtest")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer os.RemoveAll(tempDir)
|
defer func(path string) {
|
||||||
|
rerr := os.RemoveAll(path)
|
||||||
|
if rerr != nil {
|
||||||
|
panic(rerr)
|
||||||
|
}
|
||||||
|
}(tempDir)
|
||||||
|
|
||||||
certFile1 := makeTestCertFile(t, testCA1, tempDir)
|
certFile1 := makeTestCertFile(t, testCA1, tempDir)
|
||||||
certFile2 := makeTestCertFile(t, testCA2, tempDir)
|
certFile2 := makeTestCertFile(t, testCA2, tempDir)
|
||||||
|
|
||||||
|
|
@ -89,3 +97,16 @@ func TestGetCertPool(t *testing.T) {
|
||||||
expectedSubjects := []string{testCA1Subj, testCA2Subj}
|
expectedSubjects := []string{testCA1Subj, testCA2Subj}
|
||||||
assert.Equal(t, expectedSubjects, got)
|
assert.Equal(t, expectedSubjects, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetRequestHost(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "https://example.com", nil)
|
||||||
|
host := GetRequestHost(req)
|
||||||
|
g.Expect(host).To(Equal("example.com"))
|
||||||
|
|
||||||
|
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
|
||||||
|
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
|
||||||
|
extHost := GetRequestHost(proxyReq)
|
||||||
|
g.Expect(extHost).To(Equal("external.example.com"))
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue