Merge pull request #619 from oauth2-proxy/https-redirect-middleware
Improve Redirect to HTTPs behaviour
This commit is contained in:
		
						commit
						c4cf15f3e1
					
				|  | @ -8,6 +8,7 @@ | |||
| 
 | ||||
| ## Changes since v6.0.0 | ||||
| 
 | ||||
| - [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed) | ||||
| - [#654](https://github.com/oauth2-proxy/oauth2-proxy/pull/654) Close client connections after each redis test (@JoelSpeed) | ||||
| - [#542](https://github.com/oauth2-proxy/oauth2-proxy/pull/542) Move SessionStore tests to independent package (@JoelSpeed) | ||||
| - [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed) | ||||
|  |  | |||
							
								
								
									
										18
									
								
								http.go
								
								
								
								
							
							
						
						
									
										18
									
								
								http.go
								
								
								
								
							|  | @ -9,7 +9,6 @@ import ( | |||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| ) | ||||
|  | @ -129,20 +128,3 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { | |||
| 	tc.SetKeepAlivePeriod(3 * time.Minute) | ||||
| 	return tc, nil | ||||
| } | ||||
| 
 | ||||
| func newRedirectToHTTPS(opts *options.Options) alice.Constructor { | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return redirectToHTTPS(opts, next) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		proto := r.Header.Get("X-Forwarded-Proto") | ||||
| 		if opts.ForceHTTPS && (r.TLS == nil || (proto != "" && strings.ToLower(proto) != "https")) { | ||||
| 			http.Redirect(w, r, opts.HTTPSAddress, http.StatusPermanentRedirect) | ||||
| 		} | ||||
| 
 | ||||
| 		h.ServeHTTP(w, r) | ||||
| 	}) | ||||
| } | ||||
|  |  | |||
							
								
								
									
										51
									
								
								http_test.go
								
								
								
								
							
							
						
						
									
										51
									
								
								http_test.go
								
								
								
								
							|  | @ -2,7 +2,6 @@ package main | |||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | @ -11,56 +10,6 @@ import ( | |||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestRedirectToHTTPSTrue(t *testing.T) { | ||||
| 	opts := options.NewOptions() | ||||
| 	opts.ForceHTTPS = true | ||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { | ||||
| 		w.Write([]byte("test")) | ||||
| 	} | ||||
| 
 | ||||
| 	h := redirectToHTTPS(opts, http.HandlerFunc(handler)) | ||||
| 	rw := httptest.NewRecorder() | ||||
| 	r, _ := http.NewRequest("GET", "/", nil) | ||||
| 	h.ServeHTTP(rw, r) | ||||
| 
 | ||||
| 	assert.Equal(t, http.StatusPermanentRedirect, rw.Code, "status code should be %d, got: %d", http.StatusPermanentRedirect, rw.Code) | ||||
| } | ||||
| 
 | ||||
| func TestRedirectToHTTPSFalse(t *testing.T) { | ||||
| 	opts := options.NewOptions() | ||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { | ||||
| 		w.Write([]byte("test")) | ||||
| 	} | ||||
| 
 | ||||
| 	h := redirectToHTTPS(opts, http.HandlerFunc(handler)) | ||||
| 	rw := httptest.NewRecorder() | ||||
| 	r, _ := http.NewRequest("GET", "/", nil) | ||||
| 	h.ServeHTTP(rw, r) | ||||
| 
 | ||||
| 	assert.Equal(t, http.StatusOK, rw.Code, "status code should be %d, got: %d", http.StatusOK, rw.Code) | ||||
| } | ||||
| 
 | ||||
| func TestRedirectNotWhenHTTPS(t *testing.T) { | ||||
| 	opts := options.NewOptions() | ||||
| 	opts.ForceHTTPS = true | ||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { | ||||
| 		w.Write([]byte("test")) | ||||
| 	} | ||||
| 
 | ||||
| 	h := redirectToHTTPS(opts, http.HandlerFunc(handler)) | ||||
| 	s := httptest.NewTLSServer(h) | ||||
| 	defer s.Close() | ||||
| 
 | ||||
| 	opts.HTTPSAddress = s.URL | ||||
| 	client := s.Client() | ||||
| 	res, err := client.Get(s.URL) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("request to test server failed with error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	assert.Equal(t, http.StatusOK, res.StatusCode, "status code should be %d, got: %d", http.StatusOK, res.StatusCode) | ||||
| } | ||||
| 
 | ||||
| func TestGracefulShutdown(t *testing.T) { | ||||
| 	opts := options.NewOptions() | ||||
| 	stop := make(chan struct{}, 1) | ||||
|  |  | |||
							
								
								
									
										7
									
								
								main.go
								
								
								
								
							
							
						
						
									
										7
									
								
								main.go
								
								
								
								
							|  | @ -3,6 +3,7 @@ package main | |||
| import ( | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"runtime" | ||||
|  | @ -79,7 +80,11 @@ func main() { | |||
| 	chain := alice.New() | ||||
| 
 | ||||
| 	if opts.ForceHTTPS { | ||||
| 		chain = chain.Append(newRedirectToHTTPS(opts)) | ||||
| 		_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) | ||||
| 		if err != nil { | ||||
| 			logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err) | ||||
| 		} | ||||
| 		chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) | ||||
| 	} | ||||
| 
 | ||||
| 	healthCheckPaths := []string{opts.PingPath} | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
|  | @ -14,3 +15,9 @@ func TestMiddlewareSuite(t *testing.T) { | |||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "Middleware") | ||||
| } | ||||
| 
 | ||||
| func testHandler() http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		rw.Write([]byte("test")) | ||||
| 	}) | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,50 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| ) | ||||
| 
 | ||||
| const httpsScheme = "https" | ||||
| 
 | ||||
| // NewRedirectToHTTPS creates a new redirectToHTTPS middleware that will redirect
 | ||||
| // HTTP requests to HTTPS
 | ||||
| func NewRedirectToHTTPS(httpsPort string) alice.Constructor { | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return redirectToHTTPS(httpsPort, next) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // redirectToHTTPS is an HTTP middleware the will redirect a request to HTTPS
 | ||||
| // if it is not already HTTPS.
 | ||||
| // If the request is to a non standard port, the redirection request will be
 | ||||
| // to the port from the httpsAddress given.
 | ||||
| func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		proto := req.Header.Get("X-Forwarded-Proto") | ||||
| 		if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") { | ||||
| 			// Only care about the connection to us being HTTPS if the proto is empty,
 | ||||
| 			// otherwise the proto is source of truth
 | ||||
| 			next.ServeHTTP(rw, req) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// Copy the request URL
 | ||||
| 		targetURL, _ := url.Parse(req.URL.String()) | ||||
| 		// Set the scheme to HTTPS
 | ||||
| 		targetURL.Scheme = httpsScheme | ||||
| 
 | ||||
| 		// Overwrite the port if the original request was to a non-standard port
 | ||||
| 		if targetURL.Port() != "" { | ||||
| 			// If Port was not empty, this should be fine to ignore the error
 | ||||
| 			host, _, _ := net.SplitHostPort(targetURL.Host) | ||||
| 			targetURL.Host = net.JoinHostPort(host, httpsPort) | ||||
| 		} | ||||
| 
 | ||||
| 		http.Redirect(rw, req, targetURL.String(), http.StatusPermanentRedirect) | ||||
| 	}) | ||||
| } | ||||
|  | @ -0,0 +1,158 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"net/http/httptest" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("RedirectToHTTPS suite", func() { | ||||
| 	const httpsPort = "8443" | ||||
| 
 | ||||
| 	var permanentRedirectBody = func(address string) string { | ||||
| 		return fmt.Sprintf("<a href=\"%s\">Permanent Redirect</a>.\n\n", address) | ||||
| 	} | ||||
| 
 | ||||
| 	type requestTableInput struct { | ||||
| 		requestString    string | ||||
| 		useTLS           bool | ||||
| 		headers          map[string]string | ||||
| 		expectedStatus   int | ||||
| 		expectedBody     string | ||||
| 		expectedLocation string | ||||
| 	} | ||||
| 
 | ||||
| 	DescribeTable("when serving a request", | ||||
| 		func(in *requestTableInput) { | ||||
| 			req := httptest.NewRequest("", in.requestString, nil) | ||||
| 			for k, v := range in.headers { | ||||
| 				req.Header.Add(k, v) | ||||
| 			} | ||||
| 			if in.useTLS { | ||||
| 				req.TLS = &tls.ConnectionState{} | ||||
| 			} | ||||
| 
 | ||||
| 			rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 			handler := NewRedirectToHTTPS(httpsPort)(testHandler()) | ||||
| 			handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			Expect(rw.Code).To(Equal(in.expectedStatus)) | ||||
| 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | ||||
| 
 | ||||
| 			if in.expectedLocation != "" { | ||||
| 				Expect(rw.Header().Values("Location")).To(ConsistOf(in.expectedLocation)) | ||||
| 			} | ||||
| 		}, | ||||
| 		Entry("without TLS", &requestTableInput{ | ||||
| 			requestString:    "http://example.com", | ||||
| 			useTLS:           false, | ||||
| 			headers:          map[string]string{}, | ||||
| 			expectedStatus:   308, | ||||
| 			expectedBody:     permanentRedirectBody("https://example.com"), | ||||
| 			expectedLocation: "https://example.com", | ||||
| 		}), | ||||
| 		Entry("with TLS", &requestTableInput{ | ||||
| 			requestString:  "https://example.com", | ||||
| 			useTLS:         true, | ||||
| 			headers:        map[string]string{}, | ||||
| 			expectedStatus: 200, | ||||
| 			expectedBody:   "test", | ||||
| 		}), | ||||
| 		Entry("without TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ | ||||
| 			requestString: "http://example.com", | ||||
| 			useTLS:        false, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "HTTPS", | ||||
| 			}, | ||||
| 			expectedStatus: 200, | ||||
| 			expectedBody:   "test", | ||||
| 		}), | ||||
| 		Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ | ||||
| 			requestString: "https://example.com", | ||||
| 			useTLS:        true, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "HTTPS", | ||||
| 			}, | ||||
| 			expectedStatus: 200, | ||||
| 			expectedBody:   "test", | ||||
| 		}), | ||||
| 		Entry("without TLS and X-Forwarded-Proto=https", &requestTableInput{ | ||||
| 			requestString: "http://example.com", | ||||
| 			useTLS:        false, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "https", | ||||
| 			}, | ||||
| 			expectedStatus: 200, | ||||
| 			expectedBody:   "test", | ||||
| 		}), | ||||
| 		Entry("with TLS and X-Forwarded-Proto=https", &requestTableInput{ | ||||
| 			requestString: "https://example.com", | ||||
| 			useTLS:        true, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "https", | ||||
| 			}, | ||||
| 			expectedStatus: 200, | ||||
| 			expectedBody:   "test", | ||||
| 		}), | ||||
| 		Entry("without TLS and X-Forwarded-Proto=HTTP", &requestTableInput{ | ||||
| 			requestString: "http://example.com", | ||||
| 			useTLS:        false, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "HTTP", | ||||
| 			}, | ||||
| 			expectedStatus:   308, | ||||
| 			expectedBody:     permanentRedirectBody("https://example.com"), | ||||
| 			expectedLocation: "https://example.com", | ||||
| 		}), | ||||
| 		Entry("with TLS and X-Forwarded-Proto=HTTP", &requestTableInput{ | ||||
| 			requestString: "https://example.com", | ||||
| 			useTLS:        true, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "HTTP", | ||||
| 			}, | ||||
| 			expectedStatus:   308, | ||||
| 			expectedBody:     permanentRedirectBody("https://example.com"), | ||||
| 			expectedLocation: "https://example.com", | ||||
| 		}), | ||||
| 		Entry("without TLS and X-Forwarded-Proto=http", &requestTableInput{ | ||||
| 			requestString: "https://example.com", | ||||
| 			useTLS:        false, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "http", | ||||
| 			}, | ||||
| 			expectedStatus:   308, | ||||
| 			expectedBody:     permanentRedirectBody("https://example.com"), | ||||
| 			expectedLocation: "https://example.com", | ||||
| 		}), | ||||
| 		Entry("with TLS and X-Forwarded-Proto=http", &requestTableInput{ | ||||
| 			requestString: "https://example.com", | ||||
| 			useTLS:        true, | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-Proto": "http", | ||||
| 			}, | ||||
| 			expectedStatus:   308, | ||||
| 			expectedBody:     permanentRedirectBody("https://example.com"), | ||||
| 			expectedLocation: "https://example.com", | ||||
| 		}), | ||||
| 		Entry("without TLS on a non-standard port", &requestTableInput{ | ||||
| 			requestString:    "http://example.com:8080", | ||||
| 			useTLS:           false, | ||||
| 			headers:          map[string]string{}, | ||||
| 			expectedStatus:   308, | ||||
| 			expectedBody:     permanentRedirectBody("https://example.com:8443"), | ||||
| 			expectedLocation: "https://example.com:8443", | ||||
| 		}), | ||||
| 		Entry("with TLS on a non-standard port", &requestTableInput{ | ||||
| 			requestString:  "https://example.com:8443", | ||||
| 			useTLS:         true, | ||||
| 			headers:        map[string]string{}, | ||||
| 			expectedStatus: 200, | ||||
| 			expectedBody:   "test", | ||||
| 		}), | ||||
| 	) | ||||
| }) | ||||
		Loading…
	
		Reference in New Issue