From 1c1106721eff8f2414a0e87141b80494a36a9f6c Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 12 Jun 2020 18:18:41 +0100 Subject: [PATCH] Move RedirectToHTTPS to middleware package Moves the logic for redirecting to HTTPs to a middleware package and adds tests for this logic. Also makes the functionality more useful, previously it always redirected to the HTTPS address of the proxy, which may not have been intended, now it will redirect based on if a port is provided in the URL (assume public facing 80 to 443 or 4180 to 8443 for example) --- CHANGELOG.md | 1 + http.go | 18 --- http_test.go | 51 -------- main.go | 7 +- pkg/middleware/middleware_suite_test.go | 7 + pkg/middleware/redirect_to_https.go | 50 +++++++ pkg/middleware/redirect_to_https_test.go | 158 +++++++++++++++++++++++ 7 files changed, 222 insertions(+), 70 deletions(-) create mode 100644 pkg/middleware/redirect_to_https.go create mode 100644 pkg/middleware/redirect_to_https_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a158aab..33db5210 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v6.0.0 +- [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed) - [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed) - [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed) - [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed) diff --git a/http.go b/http.go index a8694187..7e1215f7 100644 --- a/http.go +++ b/http.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/justinas/alice" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" ) @@ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { tc.SetKeepAlivePeriod(3 * time.Minute) return tc, nil } - -func newRedirectToHTTPS(opts *options.Options) alice.Constructor { - return func(next http.Handler) http.Handler { - return redirectToHTTPS(opts, next) - } -} - -func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proto := r.Header.Get("X-Forwarded-Proto") - if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) { - http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect) - } - - h.ServeHTTP(w, r) - }) -} diff --git a/http_test.go b/http_test.go index 5bd58172..ba516b97 100644 --- a/http_test.go +++ b/http_test.go @@ -2,7 +2,6 @@ package main import ( "net/http" - "net/http/httptest" "sync" "testing" "time" @@ -11,56 +10,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRedirectToHTTPSTrue(t *testing.T) { - opts := options.NewOptions() - opts.ForceHTTPS = true - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := redirectToHTTPS(opts, http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/", nil) - h.ServeHTTP(rw, r) - - assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code) -} - -func TestRedirectToHTTPSFalse(t *testing.T) { - opts := options.NewOptions() - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := redirectToHTTPS(opts, http.HandlerFunc(handler)) - rw := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/", nil) - h.ServeHTTP(rw, r) - - assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code) -} - -func TestRedirectNotWhenHTTPS(t *testing.T) { - opts := options.NewOptions() - opts.ForceHTTPS = true - handler := func(w http.ResponseWriter, req *http.Request) { - w.Write([]byte("test")) - } - - h := redirectToHTTPS(opts, http.HandlerFunc(handler)) - s := httptest.NewTLSServer(h) - defer s.Close() - - opts.HTTPSAddress = s.URL - client := s.Client() - res, err := client.Get(s.URL) - if err != nil { - t.Fatalf("request to test server failed with error: %v", err) - } - - assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode) -} - func TestGracefulShutdown(t *testing.T) { opts := options.NewOptions() stop := make(chan struct{}, 1) diff --git a/main.go b/main.go index 0720ba86..4ee6af4b 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "math/rand" + "net" "os" "os/signal" "runtime" @@ -79,7 +80,11 @@ func main() { chain := alice.New() if opts.ForceHTTPS { - chain = chain.Append(newRedirectToHTTPS(opts)) + _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) + if err != nil { + logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err) + } + chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) } healthCheckPaths := []string{opts.PingPath} diff --git a/pkg/middleware/middleware_suite_test.go b/pkg/middleware/middleware_suite_test.go index 9972ce8f..204a9798 100644 --- a/pkg/middleware/middleware_suite_test.go +++ b/pkg/middleware/middleware_suite_test.go @@ -1,6 +1,7 @@ package middleware import ( + "net/http" "testing" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" @@ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Middleware") } + +func testHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Write([]byte("test")) + }) +} diff --git a/pkg/middleware/redirect_to_https.go b/pkg/middleware/redirect_to_https.go new file mode 100644 index 00000000..141d076a --- /dev/null +++ b/pkg/middleware/redirect_to_https.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "net" + "net/http" + "net/url" + "strings" + + "github.com/justinas/alice" +) + +const httpsScheme = "https" + +// NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect +// HTTP requests to HTTPS +func NewRedirectToHTTPS(httpsPort string) alice.Constructor { + return func(next http.Handler) http.Handler { + return redirectToHTTPS(httpsPort, next) + } +} + +// redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS +// if it is not already HTTPS. +// If the request is to a non standard port, the redirection request will be +// to the port from the httpsAddress given. +func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + proto := req.Header.Get("X-Forwarded-Proto") + if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") { + // Only care about the connection to us being HTTPS if the proto is empty, + // otherwise the proto is source of truth + next.ServeHTTP(rw, req) + return + } + + // Copy the request URL + targetURL, _ := url.Parse(req.URL.String()) + // Set the scheme to HTTPS + targetURL.Scheme = httpsScheme + + // Overwrite the port if the original request was to a non-standard port + if targetURL.Port() != "" { + // If Port was not empty, this should be fine to ignore the error + host, _, _ := net.SplitHostPort(targetURL.Host) + targetURL.Host = net.JoinHostPort(host, httpsPort) + } + + http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect) + }) +} diff --git a/pkg/middleware/redirect_to_https_test.go b/pkg/middleware/redirect_to_https_test.go new file mode 100644 index 00000000..c4b631b5 --- /dev/null +++ b/pkg/middleware/redirect_to_https_test.go @@ -0,0 +1,158 @@ +package middleware + +import ( + "crypto/tls" + "fmt" + "net/http/httptest" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("RedirectToHTTPS suite", func() { + const httpsPort = "8443" + + var permanentRedirectBody = func(address string) string { + return fmt.Sprintf("Permanent Redirect.\n\n", address) + } + + type requestTableInput struct { + requestString string + useTLS bool + headers map[string]string + expectedStatus int + expectedBody string + expectedLocation string + } + + DescribeTable("when serving a request", + func(in *requestTableInput) { + req := httptest.NewRequest("", in.requestString, nil) + for k, v := range in.headers { + req.Header.Add(k, v) + } + if in.useTLS { + req.TLS = &tls.ConnectionState{} + } + + rw := httptest.NewRecorder() + + handler := NewRedirectToHTTPS(httpsPort)(testHandler()) + handler.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.expectedStatus)) + Expect(rw.Body.String()).To(Equal(in.expectedBody)) + + if in.expectedLocation != "" { + Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation)) + } + }, + Entry("without TLS", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{}, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("with TLS", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{}, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTPS", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTPS", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "https", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "https", + }, + expectedStatus: 200, + expectedBody: "test", + }), + Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTP", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTP", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{ + requestString: "https://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "http", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{ + requestString: "https://example.com", + useTLS: true, + headers: map[string]string{ + "X-Forwarded-Proto": "http", + }, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), + Entry("without TLS on a non-standard port", &requestTableInput{ + requestString: "http://example.com:8080", + useTLS: false, + headers: map[string]string{}, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com:8443"), + expectedLocation: "https://example.com:8443", + }), + Entry("with TLS on a non-standard port", &requestTableInput{ + requestString: "https://example.com:8443", + useTLS: true, + headers: map[string]string{}, + expectedStatus: 200, + expectedBody: "test", + }), + ) +})