diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a158aab..33db5210 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v6.0.0 +- [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed) - [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed) - [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed) - [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed) diff --git a/http.go b/http.go index a8694187..7e1215f7 100644 --- a/http.go +++ b/http.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/justinas/alice" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" ) @@ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { tc.SetKeepAlivePeriod(3 * time.Minute) return tc, nil } - -func newRedirectToHTTPS(opts *options.Options) alice.Constructor { - return func(next http.Handler) http.Handler { - return redirectToHTTPS(opts, next) - } -} - -func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proto := r.Header.Get("X-Forwarded-Proto") - if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) { - http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect) - } - - h.ServeHTTP(w, r) - }) -} diff --git a/http_test.go b/http_test.go index 5bd58172..ba516b97 100644 --- a/http_test.go +++ b/http_test.go @@ -2,7 +2,6 @@ package main import ( "net/http" - "net/http/httptest" "sync" "testing" "time" @@ -11,56 +10,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRedirectToHTTPSTrue(t *testing.T) { - opts := options.NewOptions() - opts.ForceHTTPS = true - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := redirectToHTTPS(opts, http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/", nil) - h.ServeHTTP(rw, r) - - assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code) -} - -func TestRedirectToHTTPSFalse(t *testing.T) { - opts := options.NewOptions() - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := redirectToHTTPS(opts, http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/", nil) - h.ServeHTTP(rw, r) - - assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code) -} - -func TestRedirectNotWhenHTTPS(t *testing.T) { - opts := options.NewOptions() - opts.ForceHTTPS = true - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := redirectToHTTPS(opts, http.HandlerFunc(handler)) - s := httptest.NewTLSServer(h) - defer s.Close() - - opts.HTTPSAddress = s.URL - client := s.Client() - res, err := client.Get(s.URL) - if err != nil { - t.Fatalf("request to test server failed with error: %v", err) - } - - assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode) -} - func TestGracefulShutdown(t *testing.T) { opts := options.NewOptions() stop := make(chan struct{}, 1) diff --git a/main.go b/main.go index 0720ba86..4ee6af4b 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "math/rand" + "net" "os" "os/signal" "runtime" @@ -79,7 +80,11 @@ func main() { chain := alice.New() if opts.ForceHTTPS { - chain = chain.Append(newRedirectToHTTPS(opts)) + _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) + if err != nil { + logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err) + } + chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) } healthCheckPaths := []string{opts.PingPath} diff --git a/pkg/middleware/middleware_suite_test.go b/pkg/middleware/middleware_suite_test.go index 9972ce8f..204a9798 100644 --- a/pkg/middleware/middleware_suite_test.go +++ b/pkg/middleware/middleware_suite_test.go @@ -1,6 +1,7 @@ package middleware import ( + "net/http" "testing" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" @@ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Middleware") } + +func testHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Write([]byte("test")) + }) +} diff --git a/pkg/middleware/redirect_to_https.go b/pkg/middleware/redirect_to_https.go new file mode 100644 index 00000000..141d076a --- /dev/null +++ b/pkg/middleware/redirect_to_https.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "net" + "net/http" + "net/url" + "strings" + + "github.com/justinas/alice" +) + +const httpsScheme = "https" + +// NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect +// HTTP requests to HTTPS +func NewRedirectToHTTPS(httpsPort string) alice.Constructor { + return func(next http.Handler) http.Handler { + return redirectToHTTPS(httpsPort, next) + } +} + +// redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS +// if it is not already HTTPS. +// If the request is to a non standard port, the redirection request will be +// to the port from the httpsAddress given. +func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + proto := req.Header.Get("X-Forwarded-Proto") + if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") { + // Only care about the connection to us being HTTPS if the proto is empty, + // otherwise the proto is source of truth + next.ServeHTTP(rw, req) + return + } + + // Copy the request URL + targetURL, _ := url.Parse(req.URL.String()) + // Set the scheme to HTTPS + targetURL.Scheme = httpsScheme + + // Overwrite the port if the original request was to a non-standard port + if targetURL.Port() != "" { + // If Port was not empty, this should be fine to ignore the error + host, _, _ := net.SplitHostPort(targetURL.Host) + targetURL.Host = net.JoinHostPort(host, httpsPort) + } + + http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect) + }) +} diff --git a/pkg/middleware/redirect_to_https_test.go b/pkg/middleware/redirect_to_https_test.go new file mode 100644 index 00000000..c4b631b5 --- /dev/null +++ b/pkg/middleware/redirect_to_https_test.go @@ -0,0 +1,158 @@ +package middleware + +import ( + "crypto/tls" + "fmt" + "net/http/httptest" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("RedirectToHTTPS suite", func() { + const httpsPort = "8443" + + var permanentRedirectBody = func(address string) string { + return fmt.Sprintf("Permanent Redirect.\n\n", address) + } + + type requestTableInput struct { + requestString string + useTLS bool + headers map[string]string + expectedStatus int + expectedBody string + expectedLocation string + } + + DescribeTable("when serving a request", + func(in *requestTableInput) { + req := httptest.NewRequest("", in.requestString, nil) + for k, v := range in.headers { + req.Header.Add(k, v) + } + if in.useTLS { + req.TLS = &tls.ConnectionState{} + } + + rw := httptest.NewRecorder() + + handler := NewRedirectToHTTPS(httpsPort)(testHandler()) + handler.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.expectedStatus)) + Expect(rw.Body.String()).To(Equal(in.expectedBody)) + + if in.expectedLocation != "" { + Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation)) + } + }, + Entry("without TLS", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{}, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("with TLS", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{}, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTPS", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTPS", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "https", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "https", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTP", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTP", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{ + requestString: "https://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "http", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "http", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("without TLS on a non-standard port", &requestTableInput{ + requestString: "http://example.com:8080", + useTLS: false, + headers: map[string]string{}, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com:8443"), + expectedLocation: "https://example.com:8443", + }), + Entry("with TLS on a non-standard port", &requestTableInput{ + requestString: "https://example.com:8443", + useTLS: true, + headers: map[string]string{}, + expectedStatus: 200, + expectedBody: "test", + }), + ) +})