Merge pull request #620 from oauth2-proxy/healthcheck-middleware
Add HealthCheck middleware
This commit is contained in:
		
						commit
						713c3927a9
					
				| 
						 | 
				
			
			@ -55,6 +55,7 @@
 | 
			
		|||
 | 
			
		||||
## Changes since v5.1.1
 | 
			
		||||
 | 
			
		||||
- [#620](https://github.com/oauth2-proxy/oauth2-proxy/pull/620) Add HealthCheck middleware (@JoelSpeed)
 | 
			
		||||
- [#604](https://github.com/oauth2-proxy/oauth2-proxy/pull/604) Add Keycloak local testing environment (@EvgeniGordeev)
 | 
			
		||||
- [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves)
 | 
			
		||||
- [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										1
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										1
									
								
								go.mod
								
								
								
								
							| 
						 | 
				
			
			@ -12,6 +12,7 @@ require (
 | 
			
		|||
	github.com/dgrijalva/jwt-go v3.2.0+incompatible
 | 
			
		||||
	github.com/fsnotify/fsnotify v1.4.9
 | 
			
		||||
	github.com/go-redis/redis/v7 v7.2.0
 | 
			
		||||
	github.com/justinas/alice v1.2.0
 | 
			
		||||
	github.com/kr/pretty v0.2.0 // indirect
 | 
			
		||||
	github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa
 | 
			
		||||
	github.com/mitchellh/mapstructure v1.1.2
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										3
									
								
								go.sum
								
								
								
								
							
							
						
						
									
										3
									
								
								go.sum
								
								
								
								
							| 
						 | 
				
			
			@ -102,6 +102,8 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
 | 
			
		|||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
 | 
			
		||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
 | 
			
		||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
 | 
			
		||||
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
 | 
			
		||||
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
 | 
			
		||||
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
 | 
			
		||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
 | 
			
		||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 | 
			
		||||
| 
						 | 
				
			
			@ -129,6 +131,7 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
 | 
			
		|||
github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
 | 
			
		||||
github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU=
 | 
			
		||||
github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg=
 | 
			
		||||
github.com/onsi/ginkgo v1.12.3 h1:+RYp9QczoWz9zfUyLP/5SLXQVhfr6gZOoKGfQqHuLZQ=
 | 
			
		||||
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
 | 
			
		||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
 | 
			
		||||
github.com/onsi/gomega v1.9.0 h1:R1uwffexN6Pr340GtYRIdZmAiN4J+iw6WG4wog1DUXg=
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										46
									
								
								http.go
								
								
								
								
							
							
						
						
									
										46
									
								
								http.go
								
								
								
								
							| 
						 | 
				
			
			@ -9,6 +9,7 @@ import (
 | 
			
		|||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/justinas/alice"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -29,45 +30,6 @@ func (s *Server) ListenAndServe() {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Used with gcpHealthcheck()
 | 
			
		||||
const userAgentHeader = "User-Agent"
 | 
			
		||||
const googleHealthCheckUserAgent = "GoogleHC/1.0"
 | 
			
		||||
const rootPath = "/"
 | 
			
		||||
 | 
			
		||||
// gcpHealthcheck handles healthcheck queries from GCP.
 | 
			
		||||
func gcpHealthcheck(h http.Handler) http.Handler {
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		// Check for liveness and readiness:  used for Google App Engine
 | 
			
		||||
		if r.URL.EscapedPath() == "/liveness_check" {
 | 
			
		||||
			w.WriteHeader(http.StatusOK)
 | 
			
		||||
			w.Write([]byte("OK"))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if r.URL.EscapedPath() == "/readiness_check" {
 | 
			
		||||
			w.WriteHeader(http.StatusOK)
 | 
			
		||||
			w.Write([]byte("OK"))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Check for GKE ingress healthcheck:  The ingress requires the root
 | 
			
		||||
		// path of the target to return a 200 (OK) to indicate the service's good health. This can be quite a challenging demand
 | 
			
		||||
		// depending on the application's path structure. This middleware filters out the requests from the health check by
 | 
			
		||||
		//
 | 
			
		||||
		// 1. checking that the request path is indeed the root path
 | 
			
		||||
		// 2. ensuring that the User-Agent is "GoogleHC/1.0", the health checker
 | 
			
		||||
		// 3. ensuring the request method is "GET"
 | 
			
		||||
		if r.URL.Path == rootPath &&
 | 
			
		||||
			r.Header.Get(userAgentHeader) == googleHealthCheckUserAgent &&
 | 
			
		||||
			r.Method == http.MethodGet {
 | 
			
		||||
 | 
			
		||||
			w.WriteHeader(http.StatusOK)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		h.ServeHTTP(w, r)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | 
			
		||||
func (s *Server) ServeHTTP() {
 | 
			
		||||
	HTTPAddress := s.Opts.HTTPAddress
 | 
			
		||||
| 
						 | 
				
			
			@ -168,6 +130,12 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
 | 
			
		|||
	return tc, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newRedirectToHTTPS(opts *options.Options) alice.Constructor {
 | 
			
		||||
	return func(next http.Handler) http.Handler {
 | 
			
		||||
		return redirectToHTTPS(opts, next)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler {
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		proto := r.Header.Get("X-Forwarded-Proto")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										99
									
								
								http_test.go
								
								
								
								
							
							
						
						
									
										99
									
								
								http_test.go
								
								
								
								
							| 
						 | 
				
			
			@ -11,105 +11,6 @@ import (
 | 
			
		|||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const localhost = "127.0.0.1"
 | 
			
		||||
const host = "test-server"
 | 
			
		||||
 | 
			
		||||
func TestGCPHealthcheckLiveness(t *testing.T) {
 | 
			
		||||
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		w.Write([]byte("test"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h := gcpHealthcheck(http.HandlerFunc(handler))
 | 
			
		||||
	rw := httptest.NewRecorder()
 | 
			
		||||
	r, _ := http.NewRequest("GET", "/liveness_check", nil)
 | 
			
		||||
	r.RemoteAddr = localhost
 | 
			
		||||
	r.Host = host
 | 
			
		||||
	h.ServeHTTP(rw, r)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, 200, rw.Code)
 | 
			
		||||
	assert.Equal(t, "OK", rw.Body.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGCPHealthcheckReadiness(t *testing.T) {
 | 
			
		||||
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		w.Write([]byte("test"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h := gcpHealthcheck(http.HandlerFunc(handler))
 | 
			
		||||
	rw := httptest.NewRecorder()
 | 
			
		||||
	r, _ := http.NewRequest("GET", "/readiness_check", nil)
 | 
			
		||||
	r.RemoteAddr = localhost
 | 
			
		||||
	r.Host = host
 | 
			
		||||
	h.ServeHTTP(rw, r)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, 200, rw.Code)
 | 
			
		||||
	assert.Equal(t, "OK", rw.Body.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGCPHealthcheckNotHealthcheck(t *testing.T) {
 | 
			
		||||
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		w.Write([]byte("test"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h := gcpHealthcheck(http.HandlerFunc(handler))
 | 
			
		||||
	rw := httptest.NewRecorder()
 | 
			
		||||
	r, _ := http.NewRequest("GET", "/not_any_check", nil)
 | 
			
		||||
	r.RemoteAddr = localhost
 | 
			
		||||
	r.Host = host
 | 
			
		||||
	h.ServeHTTP(rw, r)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, "test", rw.Body.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGCPHealthcheckIngress(t *testing.T) {
 | 
			
		||||
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		w.Write([]byte("test"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h := gcpHealthcheck(http.HandlerFunc(handler))
 | 
			
		||||
	rw := httptest.NewRecorder()
 | 
			
		||||
	r, _ := http.NewRequest("GET", "/", nil)
 | 
			
		||||
	r.RemoteAddr = localhost
 | 
			
		||||
	r.Host = host
 | 
			
		||||
	r.Header.Set(userAgentHeader, googleHealthCheckUserAgent)
 | 
			
		||||
	h.ServeHTTP(rw, r)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, 200, rw.Code)
 | 
			
		||||
	assert.Equal(t, "", rw.Body.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGCPHealthcheckNotIngress(t *testing.T) {
 | 
			
		||||
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		w.Write([]byte("test"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h := gcpHealthcheck(http.HandlerFunc(handler))
 | 
			
		||||
	rw := httptest.NewRecorder()
 | 
			
		||||
	r, _ := http.NewRequest("GET", "/foo", nil)
 | 
			
		||||
	r.RemoteAddr = localhost
 | 
			
		||||
	r.Host = host
 | 
			
		||||
	r.Header.Set(userAgentHeader, googleHealthCheckUserAgent)
 | 
			
		||||
	h.ServeHTTP(rw, r)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, "test", rw.Body.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGCPHealthcheckNotIngressPut(t *testing.T) {
 | 
			
		||||
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		w.Write([]byte("test"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h := gcpHealthcheck(http.HandlerFunc(handler))
 | 
			
		||||
	rw := httptest.NewRecorder()
 | 
			
		||||
	r, _ := http.NewRequest("PUT", "/", nil)
 | 
			
		||||
	r.RemoteAddr = localhost
 | 
			
		||||
	r.Host = host
 | 
			
		||||
	r.Header.Set(userAgentHeader, googleHealthCheckUserAgent)
 | 
			
		||||
	h.ServeHTTP(rw, r)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, "test", rw.Body.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRedirectToHTTPSTrue(t *testing.T) {
 | 
			
		||||
	opts := options.NewOptions()
 | 
			
		||||
	opts.ForceHTTPS = true
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,7 +21,6 @@ type responseLogger struct {
 | 
			
		|||
	size     int
 | 
			
		||||
	upstream string
 | 
			
		||||
	authInfo string
 | 
			
		||||
	silent   bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Header returns the ResponseWriter's Header
 | 
			
		||||
| 
						 | 
				
			
			@ -105,7 +104,5 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		|||
	url := *req.URL
 | 
			
		||||
	responseLogger := &responseLogger{w: w}
 | 
			
		||||
	h.handler.ServeHTTP(responseLogger, req)
 | 
			
		||||
	if !responseLogger.silent {
 | 
			
		||||
	logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,14 +5,11 @@ import (
 | 
			
		|||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestLoggingHandler_ServeHTTP(t *testing.T) {
 | 
			
		||||
| 
						 | 
				
			
			@ -23,20 +20,19 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
 | 
			
		|||
		ExpectedLogMessage,
 | 
			
		||||
		Path string
 | 
			
		||||
		ExcludePaths []string
 | 
			
		||||
		SilencePingLogging bool
 | 
			
		||||
	}{
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, false},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, true},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}, false},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}, false},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{}, true},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, false},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, true},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}, false},
 | 
			
		||||
		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}, true},
 | 
			
		||||
		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}, false},
 | 
			
		||||
		{"{{.RequestMethod}}", "GET\n", "/ping", []string{}, false},
 | 
			
		||||
		{"{{.RequestMethod}}", "", "/ping", []string{"/ping"}, true},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{}},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}},
 | 
			
		||||
		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}},
 | 
			
		||||
		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}},
 | 
			
		||||
		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}},
 | 
			
		||||
		{"{{.RequestMethod}}", "GET\n", "/ping", []string{}},
 | 
			
		||||
		{"{{.RequestMethod}}", "", "/ping", []string{"/ping"}},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
| 
						 | 
				
			
			@ -52,9 +48,6 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
 | 
			
		|||
 | 
			
		||||
		logger.SetOutput(buf)
 | 
			
		||||
		logger.SetReqTemplate(test.Format)
 | 
			
		||||
		if test.SilencePingLogging {
 | 
			
		||||
			test.ExcludePaths = append(test.ExcludePaths, "/ping")
 | 
			
		||||
		}
 | 
			
		||||
		logger.SetExcludePaths(test.ExcludePaths)
 | 
			
		||||
		h := LoggingHandler(http.HandlerFunc(handler))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -70,59 +63,3 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
 | 
			
		|||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestLoggingHandler_PingUserAgent(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		ExpectedLogMessage string
 | 
			
		||||
		Path               string
 | 
			
		||||
		SilencePingLogging bool
 | 
			
		||||
		WithUserAgent      string
 | 
			
		||||
	}{
 | 
			
		||||
		{"444\n", "/foo", true, "Blah"},
 | 
			
		||||
		{"444\n", "/foo", false, "Blah"},
 | 
			
		||||
		{"", "/ping", true, "Blah"},
 | 
			
		||||
		{"200\n", "/ping", false, "Blah"},
 | 
			
		||||
		{"", "/ping", true, "PingMe!"},
 | 
			
		||||
		{"", "/ping", false, "PingMe!"},
 | 
			
		||||
		{"", "/foo", true, "PingMe!"},
 | 
			
		||||
		{"", "/foo", false, "PingMe!"},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for idx, test := range tests {
 | 
			
		||||
		t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) {
 | 
			
		||||
			opts := options.NewOptions()
 | 
			
		||||
			opts.PingUserAgent = "PingMe!"
 | 
			
		||||
			opts.SkipAuthRegex = []string{"/foo"}
 | 
			
		||||
			opts.Upstreams = []string{"static://444/foo"}
 | 
			
		||||
			opts.Logging.SilencePing = test.SilencePingLogging
 | 
			
		||||
			if test.SilencePingLogging {
 | 
			
		||||
				opts.Logging.ExcludePaths = []string{"/ping"}
 | 
			
		||||
			}
 | 
			
		||||
			opts.RawRedirectURL = "localhost"
 | 
			
		||||
			validation.Validate(opts)
 | 
			
		||||
 | 
			
		||||
			p := NewOAuthProxy(opts, func(email string) bool {
 | 
			
		||||
				return true
 | 
			
		||||
			})
 | 
			
		||||
			p.provider = NewTestProvider(&url.URL{Host: "localhost"}, "")
 | 
			
		||||
 | 
			
		||||
			buf := bytes.NewBuffer(nil)
 | 
			
		||||
			logger.SetOutput(buf)
 | 
			
		||||
			logger.SetReqEnabled(true)
 | 
			
		||||
			logger.SetReqTemplate("{{.StatusCode}}")
 | 
			
		||||
 | 
			
		||||
			r, _ := http.NewRequest("GET", test.Path, nil)
 | 
			
		||||
			if test.WithUserAgent != "" {
 | 
			
		||||
				r.Header.Set("User-Agent", test.WithUserAgent)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			h := LoggingHandler(p)
 | 
			
		||||
			h.ServeHTTP(httptest.NewRecorder(), r)
 | 
			
		||||
 | 
			
		||||
			actual := buf.String()
 | 
			
		||||
			if !strings.Contains(actual, test.ExpectedLogMessage) {
 | 
			
		||||
				t.Errorf("Log message was\n%s\ninstead of matching \n%s", actual, test.ExpectedLogMessage)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										30
									
								
								main.go
								
								
								
								
							
							
						
						
									
										30
									
								
								main.go
								
								
								
								
							| 
						 | 
				
			
			@ -3,7 +3,6 @@ package main
 | 
			
		|||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"runtime"
 | 
			
		||||
| 
						 | 
				
			
			@ -11,8 +10,10 @@ import (
 | 
			
		|||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/justinas/alice"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -71,14 +72,29 @@ func main() {
 | 
			
		|||
 | 
			
		||||
	rand.Seed(time.Now().UnixNano())
 | 
			
		||||
 | 
			
		||||
	var handler http.Handler
 | 
			
		||||
	if opts.GCPHealthChecks {
 | 
			
		||||
		handler = redirectToHTTPS(opts, gcpHealthcheck(LoggingHandler(oauthproxy)))
 | 
			
		||||
	} else {
 | 
			
		||||
		handler = redirectToHTTPS(opts, LoggingHandler(oauthproxy))
 | 
			
		||||
	chain := alice.New()
 | 
			
		||||
 | 
			
		||||
	if opts.ForceHTTPS {
 | 
			
		||||
		chain = chain.Append(newRedirectToHTTPS(opts))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	healthCheckPaths := []string{opts.PingPath}
 | 
			
		||||
	healthCheckUserAgents := []string{opts.PingUserAgent}
 | 
			
		||||
	if opts.GCPHealthChecks {
 | 
			
		||||
		healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check")
 | 
			
		||||
		healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// To silence logging of health checks, register the health check handler before
 | 
			
		||||
	// the logging handler
 | 
			
		||||
	if opts.Logging.SilencePing {
 | 
			
		||||
		chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler)
 | 
			
		||||
	} else {
 | 
			
		||||
		chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s := &Server{
 | 
			
		||||
		Handler: handler,
 | 
			
		||||
		Handler: chain.Then(oauthproxy),
 | 
			
		||||
		Opts:    opts,
 | 
			
		||||
		stop:    make(chan struct{}, 1),
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -81,9 +81,6 @@ type OAuthProxy struct {
 | 
			
		|||
	Validator      func(string) bool
 | 
			
		||||
 | 
			
		||||
	RobotsPath        string
 | 
			
		||||
	PingPath          string
 | 
			
		||||
	PingUserAgent     string
 | 
			
		||||
	SilencePings      bool
 | 
			
		||||
	SignInPath        string
 | 
			
		||||
	SignOutPath       string
 | 
			
		||||
	OAuthStartPath    string
 | 
			
		||||
| 
						 | 
				
			
			@ -313,9 +310,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) *OAuthPro
 | 
			
		|||
		Validator:      validator,
 | 
			
		||||
 | 
			
		||||
		RobotsPath:        "/robots.txt",
 | 
			
		||||
		PingPath:          opts.PingPath,
 | 
			
		||||
		PingUserAgent:     opts.PingUserAgent,
 | 
			
		||||
		SilencePings:      opts.Logging.SilencePing,
 | 
			
		||||
		SignInPath:        fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
 | 
			
		||||
		SignOutPath:       fmt.Sprintf("%s/sign_out", opts.ProxyPrefix),
 | 
			
		||||
		OAuthStartPath:    fmt.Sprintf("%s/start", opts.ProxyPrefix),
 | 
			
		||||
| 
						 | 
				
			
			@ -468,17 +462,6 @@ func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
 | 
			
		|||
	fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PingPage responds 200 OK to requests
 | 
			
		||||
func (p *OAuthProxy) PingPage(rw http.ResponseWriter) {
 | 
			
		||||
	if p.SilencePings {
 | 
			
		||||
		if rl, ok := rw.(*responseLogger); ok {
 | 
			
		||||
			rl.silent = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	rw.WriteHeader(http.StatusOK)
 | 
			
		||||
	fmt.Fprintf(rw, "OK")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ErrorPage writes an error response
 | 
			
		||||
func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) {
 | 
			
		||||
	rw.WriteHeader(code)
 | 
			
		||||
| 
						 | 
				
			
			@ -684,17 +667,6 @@ func prepareNoCache(w http.ResponseWriter) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsPingRequest will check if the request appears to be performing a health check
 | 
			
		||||
// either via the path it's requesting or by a special User-Agent configuration.
 | 
			
		||||
func (p *OAuthProxy) IsPingRequest(req *http.Request) bool {
 | 
			
		||||
 | 
			
		||||
	if req.URL.EscapedPath() == p.PingPath {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return p.PingUserAgent != "" && req.Header.Get("User-Agent") == p.PingUserAgent
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	if strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
 | 
			
		||||
		prepareNoCache(rw)
 | 
			
		||||
| 
						 | 
				
			
			@ -703,8 +675,6 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 | 
			
		|||
	switch path := req.URL.Path; {
 | 
			
		||||
	case path == p.RobotsPath:
 | 
			
		||||
		p.RobotsTxt(rw)
 | 
			
		||||
	case p.IsPingRequest(req):
 | 
			
		||||
		p.PingPage(rw)
 | 
			
		||||
	case p.IsWhitelistedRequest(req):
 | 
			
		||||
		p.serveMux.ServeHTTP(rw, req)
 | 
			
		||||
	case path == p.SignInPath:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,48 @@
 | 
			
		|||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/justinas/alice"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewHealthCheck(paths, userAgents []string) alice.Constructor {
 | 
			
		||||
	return func(next http.Handler) http.Handler {
 | 
			
		||||
		return healthCheck(paths, userAgents, next)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func healthCheck(paths, userAgents []string, next http.Handler) http.Handler {
 | 
			
		||||
	// Use a map as a set to check health check paths
 | 
			
		||||
	pathSet := make(map[string]struct{})
 | 
			
		||||
	for _, path := range paths {
 | 
			
		||||
		pathSet[path] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Use a map as a set to check health check paths
 | 
			
		||||
	userAgentSet := make(map[string]struct{})
 | 
			
		||||
	for _, userAgent := range userAgents {
 | 
			
		||||
		userAgentSet[userAgent] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | 
			
		||||
		if isHealthCheckRequest(pathSet, userAgentSet, req) {
 | 
			
		||||
			rw.WriteHeader(http.StatusOK)
 | 
			
		||||
			fmt.Fprintf(rw, "OK")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		next.ServeHTTP(rw, req)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isHealthCheckRequest(paths, userAgents map[string]struct{}, req *http.Request) bool {
 | 
			
		||||
	if _, ok := paths[req.URL.EscapedPath()]; ok {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := userAgents[req.Header.Get("User-Agent")]; ok {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,112 @@
 | 
			
		|||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
 | 
			
		||||
	. "github.com/onsi/ginkgo"
 | 
			
		||||
	. "github.com/onsi/ginkgo/extensions/table"
 | 
			
		||||
	. "github.com/onsi/gomega"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ = Describe("HealthCheck suite", func() {
 | 
			
		||||
	type requestTableInput struct {
 | 
			
		||||
		healthCheckPaths      []string
 | 
			
		||||
		healthCheckUserAgents []string
 | 
			
		||||
		requestString         string
 | 
			
		||||
		headers               map[string]string
 | 
			
		||||
		expectedStatus        int
 | 
			
		||||
		expectedBody          string
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	DescribeTable("when serving a request",
 | 
			
		||||
		func(in *requestTableInput) {
 | 
			
		||||
			req := httptest.NewRequest("", in.requestString, nil)
 | 
			
		||||
			for k, v := range in.headers {
 | 
			
		||||
				req.Header.Add(k, v)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			rw := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
			handler := NewHealthCheck(in.healthCheckPaths, in.healthCheckUserAgents)(http.NotFoundHandler())
 | 
			
		||||
			handler.ServeHTTP(rw, req)
 | 
			
		||||
 | 
			
		||||
			Expect(rw.Code).To(Equal(in.expectedStatus))
 | 
			
		||||
			Expect(rw.Body.String()).To(Equal(in.expectedBody))
 | 
			
		||||
		},
 | 
			
		||||
		Entry("when requesting the healthcheck path", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/ping",
 | 
			
		||||
			headers:               map[string]string{},
 | 
			
		||||
			expectedStatus:        200,
 | 
			
		||||
			expectedBody:          "OK",
 | 
			
		||||
		}),
 | 
			
		||||
		Entry("when requesting a different path", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/different",
 | 
			
		||||
			headers:               map[string]string{},
 | 
			
		||||
			expectedStatus:        404,
 | 
			
		||||
			expectedBody:          "404 page not found\n",
 | 
			
		||||
		}),
 | 
			
		||||
		Entry("with a request from the health check user agent", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/abc",
 | 
			
		||||
			headers: map[string]string{
 | 
			
		||||
				"User-Agent": "hc/1.0",
 | 
			
		||||
			},
 | 
			
		||||
			expectedStatus: 200,
 | 
			
		||||
			expectedBody:   "OK",
 | 
			
		||||
		}),
 | 
			
		||||
		Entry("with a request from a different user agent", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/abc",
 | 
			
		||||
			headers: map[string]string{
 | 
			
		||||
				"User-Agent": "different",
 | 
			
		||||
			},
 | 
			
		||||
			expectedStatus: 404,
 | 
			
		||||
			expectedBody:   "404 page not found\n",
 | 
			
		||||
		}),
 | 
			
		||||
		Entry("with multiple paths, request one of the healthcheck paths", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping", "/liveness_check", "/readiness_check"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/readiness_check",
 | 
			
		||||
			headers:               map[string]string{},
 | 
			
		||||
			expectedStatus:        200,
 | 
			
		||||
			expectedBody:          "OK",
 | 
			
		||||
		}),
 | 
			
		||||
		Entry("with multiple paths, request none of the healthcheck paths", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping", "/liveness_check", "/readiness_check"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/readiness",
 | 
			
		||||
			headers: map[string]string{
 | 
			
		||||
				"User-Agent": "user",
 | 
			
		||||
			},
 | 
			
		||||
			expectedStatus: 404,
 | 
			
		||||
			expectedBody:   "404 page not found\n",
 | 
			
		||||
		}),
 | 
			
		||||
		Entry("with multiple user agents, request from a health check user agent", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0", "GoogleHC/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/abc",
 | 
			
		||||
			headers: map[string]string{
 | 
			
		||||
				"User-Agent": "GoogleHC/1.0",
 | 
			
		||||
			},
 | 
			
		||||
			expectedStatus: 200,
 | 
			
		||||
			expectedBody:   "OK",
 | 
			
		||||
		}),
 | 
			
		||||
		Entry("with multiple user agents, request from none of the health check user agents", &requestTableInput{
 | 
			
		||||
			healthCheckPaths:      []string{"/ping"},
 | 
			
		||||
			healthCheckUserAgents: []string{"hc/1.0", "GoogleHC/1.0"},
 | 
			
		||||
			requestString:         "http://example.com/abc",
 | 
			
		||||
			headers: map[string]string{
 | 
			
		||||
				"User-Agent": "user",
 | 
			
		||||
			},
 | 
			
		||||
			expectedStatus: 404,
 | 
			
		||||
			expectedBody:   "404 page not found\n",
 | 
			
		||||
		}),
 | 
			
		||||
	)
 | 
			
		||||
})
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,16 @@
 | 
			
		|||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
 | 
			
		||||
	. "github.com/onsi/ginkgo"
 | 
			
		||||
	. "github.com/onsi/gomega"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestMiddlewareSuite(t *testing.T) {
 | 
			
		||||
	logger.SetOutput(GinkgoWriter)
 | 
			
		||||
 | 
			
		||||
	RegisterFailHandler(Fail)
 | 
			
		||||
	RunSpecs(t, "Middleware")
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -9,7 +9,7 @@ import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
// configureLogger is responsible for configuring the logger based on the options given
 | 
			
		||||
func configureLogger(o options.Logging, pingPath string, msgs []string) []string {
 | 
			
		||||
func configureLogger(o options.Logging, msgs []string) []string {
 | 
			
		||||
	// Setup the log file
 | 
			
		||||
	if len(o.File.Filename) > 0 {
 | 
			
		||||
		// Validate that the file/dir can be written
 | 
			
		||||
| 
						 | 
				
			
			@ -48,11 +48,7 @@ func configureLogger(o options.Logging, pingPath string, msgs []string) []string
 | 
			
		|||
	logger.SetAuthTemplate(o.AuthFormat)
 | 
			
		||||
	logger.SetReqTemplate(o.RequestFormat)
 | 
			
		||||
 | 
			
		||||
	excludePaths := o.ExcludePaths
 | 
			
		||||
	if o.SilencePing {
 | 
			
		||||
		excludePaths = append(excludePaths, pingPath)
 | 
			
		||||
	}
 | 
			
		||||
	logger.SetExcludePaths(excludePaths)
 | 
			
		||||
	logger.SetExcludePaths(o.ExcludePaths)
 | 
			
		||||
 | 
			
		||||
	if !o.LocalTime {
 | 
			
		||||
		logger.SetFlags(logger.Flags() | logger.LUTC)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -264,7 +264,7 @@ func Validate(o *options.Options) error {
 | 
			
		|||
 | 
			
		||||
	msgs = parseSignatureKey(o, msgs)
 | 
			
		||||
	msgs = validateCookieName(o, msgs)
 | 
			
		||||
	msgs = configureLogger(o.Logging, o.PingPath, msgs)
 | 
			
		||||
	msgs = configureLogger(o.Logging, msgs)
 | 
			
		||||
 | 
			
		||||
	if o.ReverseProxy {
 | 
			
		||||
		parser, err := ip.GetRealClientIPParser(o.RealClientIPHeader)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue