Integrate HealthCheck middleware
This commit is contained in:
		
							parent
							
								
									ca416a2ebb
								
							
						
					
					
						commit
						9bbd6adce9
					
				
							
								
								
									
										46
									
								
								http.go
								
								
								
								
							
							
						
						
									
										46
									
								
								http.go
								
								
								
								
							|  | @ -9,6 +9,7 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
| ) | ) | ||||||
|  | @ -29,45 +30,6 @@ func (s *Server) ListenAndServe() { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Used with gcpHealthcheck()
 |  | ||||||
| const userAgentHeader = "User-Agent" |  | ||||||
| const googleHealthCheckUserAgent = "GoogleHC/1.0" |  | ||||||
| const rootPath = "/" |  | ||||||
| 
 |  | ||||||
| // gcpHealthcheck handles healthcheck queries from GCP.
 |  | ||||||
| func gcpHealthcheck(h http.Handler) http.Handler { |  | ||||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 		// Check for liveness and readiness:  used for Google App Engine
 |  | ||||||
| 		if r.URL.EscapedPath() == "/liveness_check" { |  | ||||||
| 			w.WriteHeader(http.StatusOK) |  | ||||||
| 			w.Write([]byte("OK")) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		if r.URL.EscapedPath() == "/readiness_check" { |  | ||||||
| 			w.WriteHeader(http.StatusOK) |  | ||||||
| 			w.Write([]byte("OK")) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Check for GKE ingress healthcheck:  The ingress requires the root
 |  | ||||||
| 		// path of the target to return a 200 (OK) to indicate the service's good health. This can be quite a challenging demand
 |  | ||||||
| 		// depending on the application's path structure. This middleware filters out the requests from the health check by
 |  | ||||||
| 		//
 |  | ||||||
| 		// 1. checking that the request path is indeed the root path
 |  | ||||||
| 		// 2. ensuring that the User-Agent is "GoogleHC/1.0", the health checker
 |  | ||||||
| 		// 3. ensuring the request method is "GET"
 |  | ||||||
| 		if r.URL.Path == rootPath && |  | ||||||
| 			r.Header.Get(userAgentHeader) == googleHealthCheckUserAgent && |  | ||||||
| 			r.Method == http.MethodGet { |  | ||||||
| 
 |  | ||||||
| 			w.WriteHeader(http.StatusOK) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		h.ServeHTTP(w, r) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | // ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | ||||||
| func (s *Server) ServeHTTP() { | func (s *Server) ServeHTTP() { | ||||||
| 	HTTPAddress := s.Opts.HTTPAddress | 	HTTPAddress := s.Opts.HTTPAddress | ||||||
|  | @ -168,6 +130,12 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { | ||||||
| 	return tc, nil | 	return tc, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func newRedirectToHTTPS(opts *options.Options) alice.Constructor { | ||||||
|  | 	return func(next http.Handler) http.Handler { | ||||||
|  | 		return redirectToHTTPS(opts, next) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler { | func redirectToHTTPS(opts *options.Options, h http.Handler) http.Handler { | ||||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		proto := r.Header.Get("X-Forwarded-Proto") | 		proto := r.Header.Get("X-Forwarded-Proto") | ||||||
|  |  | ||||||
							
								
								
									
										99
									
								
								http_test.go
								
								
								
								
							
							
						
						
									
										99
									
								
								http_test.go
								
								
								
								
							|  | @ -11,105 +11,6 @@ import ( | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const localhost = "127.0.0.1" |  | ||||||
| const host = "test-server" |  | ||||||
| 
 |  | ||||||
| func TestGCPHealthcheckLiveness(t *testing.T) { |  | ||||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 		w.Write([]byte("test")) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	h := gcpHealthcheck(http.HandlerFunc(handler)) |  | ||||||
| 	rw := httptest.NewRecorder() |  | ||||||
| 	r, _ := http.NewRequest("GET", "/liveness_check", nil) |  | ||||||
| 	r.RemoteAddr = localhost |  | ||||||
| 	r.Host = host |  | ||||||
| 	h.ServeHTTP(rw, r) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, 200, rw.Code) |  | ||||||
| 	assert.Equal(t, "OK", rw.Body.String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestGCPHealthcheckReadiness(t *testing.T) { |  | ||||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 		w.Write([]byte("test")) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	h := gcpHealthcheck(http.HandlerFunc(handler)) |  | ||||||
| 	rw := httptest.NewRecorder() |  | ||||||
| 	r, _ := http.NewRequest("GET", "/readiness_check", nil) |  | ||||||
| 	r.RemoteAddr = localhost |  | ||||||
| 	r.Host = host |  | ||||||
| 	h.ServeHTTP(rw, r) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, 200, rw.Code) |  | ||||||
| 	assert.Equal(t, "OK", rw.Body.String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestGCPHealthcheckNotHealthcheck(t *testing.T) { |  | ||||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 		w.Write([]byte("test")) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	h := gcpHealthcheck(http.HandlerFunc(handler)) |  | ||||||
| 	rw := httptest.NewRecorder() |  | ||||||
| 	r, _ := http.NewRequest("GET", "/not_any_check", nil) |  | ||||||
| 	r.RemoteAddr = localhost |  | ||||||
| 	r.Host = host |  | ||||||
| 	h.ServeHTTP(rw, r) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, "test", rw.Body.String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestGCPHealthcheckIngress(t *testing.T) { |  | ||||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 		w.Write([]byte("test")) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	h := gcpHealthcheck(http.HandlerFunc(handler)) |  | ||||||
| 	rw := httptest.NewRecorder() |  | ||||||
| 	r, _ := http.NewRequest("GET", "/", nil) |  | ||||||
| 	r.RemoteAddr = localhost |  | ||||||
| 	r.Host = host |  | ||||||
| 	r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) |  | ||||||
| 	h.ServeHTTP(rw, r) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, 200, rw.Code) |  | ||||||
| 	assert.Equal(t, "", rw.Body.String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestGCPHealthcheckNotIngress(t *testing.T) { |  | ||||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 		w.Write([]byte("test")) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	h := gcpHealthcheck(http.HandlerFunc(handler)) |  | ||||||
| 	rw := httptest.NewRecorder() |  | ||||||
| 	r, _ := http.NewRequest("GET", "/foo", nil) |  | ||||||
| 	r.RemoteAddr = localhost |  | ||||||
| 	r.Host = host |  | ||||||
| 	r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) |  | ||||||
| 	h.ServeHTTP(rw, r) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, "test", rw.Body.String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestGCPHealthcheckNotIngressPut(t *testing.T) { |  | ||||||
| 	handler := func(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 		w.Write([]byte("test")) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	h := gcpHealthcheck(http.HandlerFunc(handler)) |  | ||||||
| 	rw := httptest.NewRecorder() |  | ||||||
| 	r, _ := http.NewRequest("PUT", "/", nil) |  | ||||||
| 	r.RemoteAddr = localhost |  | ||||||
| 	r.Host = host |  | ||||||
| 	r.Header.Set(userAgentHeader, googleHealthCheckUserAgent) |  | ||||||
| 	h.ServeHTTP(rw, r) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, "test", rw.Body.String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestRedirectToHTTPSTrue(t *testing.T) { | func TestRedirectToHTTPSTrue(t *testing.T) { | ||||||
| 	opts := options.NewOptions() | 	opts := options.NewOptions() | ||||||
| 	opts.ForceHTTPS = true | 	opts.ForceHTTPS = true | ||||||
|  |  | ||||||
|  | @ -21,7 +21,6 @@ type responseLogger struct { | ||||||
| 	size     int | 	size     int | ||||||
| 	upstream string | 	upstream string | ||||||
| 	authInfo string | 	authInfo string | ||||||
| 	silent   bool |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Header returns the ResponseWriter's Header
 | // Header returns the ResponseWriter's Header
 | ||||||
|  | @ -105,7 +104,5 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | ||||||
| 	url := *req.URL | 	url := *req.URL | ||||||
| 	responseLogger := &responseLogger{w: w} | 	responseLogger := &responseLogger{w: w} | ||||||
| 	h.handler.ServeHTTP(responseLogger, req) | 	h.handler.ServeHTTP(responseLogger, req) | ||||||
| 	if !responseLogger.silent { |  | ||||||
| 	logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) | 	logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -5,14 +5,11 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestLoggingHandler_ServeHTTP(t *testing.T) { | func TestLoggingHandler_ServeHTTP(t *testing.T) { | ||||||
|  | @ -23,20 +20,19 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { | ||||||
| 		ExpectedLogMessage, | 		ExpectedLogMessage, | ||||||
| 		Path string | 		Path string | ||||||
| 		ExcludePaths []string | 		ExcludePaths []string | ||||||
| 		SilencePingLogging bool |  | ||||||
| 	}{ | 	}{ | ||||||
| 		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, false}, | 		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}}, | ||||||
| 		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}, true}, | 		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{}}, | ||||||
| 		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}, false}, | 		{logger.DefaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", logger.FormatTimestamp(ts)), "/foo/bar", []string{"/ping"}}, | ||||||
| 		{logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}, false}, | 		{logger.DefaultRequestLoggingFormat, "", "/foo/bar", []string{"/foo/bar"}}, | ||||||
| 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{}, true}, | 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{}}, | ||||||
| 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, false}, | 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}}, | ||||||
| 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}, true}, | 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/ping"}}, | ||||||
| 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}, false}, | 		{logger.DefaultRequestLoggingFormat, "", "/ping", []string{"/foo/bar", "/ping"}}, | ||||||
| 		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}, true}, | 		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{}}, | ||||||
| 		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}, false}, | 		{"{{.RequestMethod}}", "GET\n", "/foo/bar", []string{"/ping"}}, | ||||||
| 		{"{{.RequestMethod}}", "GET\n", "/ping", []string{}, false}, | 		{"{{.RequestMethod}}", "GET\n", "/ping", []string{}}, | ||||||
| 		{"{{.RequestMethod}}", "", "/ping", []string{"/ping"}, true}, | 		{"{{.RequestMethod}}", "", "/ping", []string{"/ping"}}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
|  | @ -52,9 +48,6 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 		logger.SetOutput(buf) | 		logger.SetOutput(buf) | ||||||
| 		logger.SetReqTemplate(test.Format) | 		logger.SetReqTemplate(test.Format) | ||||||
| 		if test.SilencePingLogging { |  | ||||||
| 			test.ExcludePaths = append(test.ExcludePaths, "/ping") |  | ||||||
| 		} |  | ||||||
| 		logger.SetExcludePaths(test.ExcludePaths) | 		logger.SetExcludePaths(test.ExcludePaths) | ||||||
| 		h := LoggingHandler(http.HandlerFunc(handler)) | 		h := LoggingHandler(http.HandlerFunc(handler)) | ||||||
| 
 | 
 | ||||||
|  | @ -70,59 +63,3 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func TestLoggingHandler_PingUserAgent(t *testing.T) { |  | ||||||
| 	tests := []struct { |  | ||||||
| 		ExpectedLogMessage string |  | ||||||
| 		Path               string |  | ||||||
| 		SilencePingLogging bool |  | ||||||
| 		WithUserAgent      string |  | ||||||
| 	}{ |  | ||||||
| 		{"444\n", "/foo", true, "Blah"}, |  | ||||||
| 		{"444\n", "/foo", false, "Blah"}, |  | ||||||
| 		{"", "/ping", true, "Blah"}, |  | ||||||
| 		{"200\n", "/ping", false, "Blah"}, |  | ||||||
| 		{"", "/ping", true, "PingMe!"}, |  | ||||||
| 		{"", "/ping", false, "PingMe!"}, |  | ||||||
| 		{"", "/foo", true, "PingMe!"}, |  | ||||||
| 		{"", "/foo", false, "PingMe!"}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for idx, test := range tests { |  | ||||||
| 		t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) { |  | ||||||
| 			opts := options.NewOptions() |  | ||||||
| 			opts.PingUserAgent = "PingMe!" |  | ||||||
| 			opts.SkipAuthRegex = []string{"/foo"} |  | ||||||
| 			opts.Upstreams = []string{"static://444/foo"} |  | ||||||
| 			opts.Logging.SilencePing = test.SilencePingLogging |  | ||||||
| 			if test.SilencePingLogging { |  | ||||||
| 				opts.Logging.ExcludePaths = []string{"/ping"} |  | ||||||
| 			} |  | ||||||
| 			opts.RawRedirectURL = "localhost" |  | ||||||
| 			validation.Validate(opts) |  | ||||||
| 
 |  | ||||||
| 			p := NewOAuthProxy(opts, func(email string) bool { |  | ||||||
| 				return true |  | ||||||
| 			}) |  | ||||||
| 			p.provider = NewTestProvider(&url.URL{Host: "localhost"}, "") |  | ||||||
| 
 |  | ||||||
| 			buf := bytes.NewBuffer(nil) |  | ||||||
| 			logger.SetOutput(buf) |  | ||||||
| 			logger.SetReqEnabled(true) |  | ||||||
| 			logger.SetReqTemplate("{{.StatusCode}}") |  | ||||||
| 
 |  | ||||||
| 			r, _ := http.NewRequest("GET", test.Path, nil) |  | ||||||
| 			if test.WithUserAgent != "" { |  | ||||||
| 				r.Header.Set("User-Agent", test.WithUserAgent) |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			h := LoggingHandler(p) |  | ||||||
| 			h.ServeHTTP(httptest.NewRecorder(), r) |  | ||||||
| 
 |  | ||||||
| 			actual := buf.String() |  | ||||||
| 			if !strings.Contains(actual, test.ExpectedLogMessage) { |  | ||||||
| 				t.Errorf("Log message was\n%s\ninstead of matching \n%s", actual, test.ExpectedLogMessage) |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
							
								
								
									
										30
									
								
								main.go
								
								
								
								
							
							
						
						
									
										30
									
								
								main.go
								
								
								
								
							|  | @ -3,7 +3,6 @@ package main | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net/http" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" | 	"os/signal" | ||||||
| 	"runtime" | 	"runtime" | ||||||
|  | @ -11,8 +10,10 @@ import ( | ||||||
| 	"syscall" | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -71,14 +72,29 @@ func main() { | ||||||
| 
 | 
 | ||||||
| 	rand.Seed(time.Now().UnixNano()) | 	rand.Seed(time.Now().UnixNano()) | ||||||
| 
 | 
 | ||||||
| 	var handler http.Handler | 	chain := alice.New() | ||||||
| 	if opts.GCPHealthChecks { | 
 | ||||||
| 		handler = redirectToHTTPS(opts, gcpHealthcheck(LoggingHandler(oauthproxy))) | 	if opts.ForceHTTPS { | ||||||
| 	} else { | 		chain = chain.Append(newRedirectToHTTPS(opts)) | ||||||
| 		handler = redirectToHTTPS(opts, LoggingHandler(oauthproxy)) |  | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	healthCheckPaths := []string{opts.PingPath} | ||||||
|  | 	healthCheckUserAgents := []string{opts.PingUserAgent} | ||||||
|  | 	if opts.GCPHealthChecks { | ||||||
|  | 		healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check") | ||||||
|  | 		healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// To silence logging of health checks, register the health check handler before
 | ||||||
|  | 	// the logging handler
 | ||||||
|  | 	if opts.Logging.SilencePing { | ||||||
|  | 		chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler) | ||||||
|  | 	} else { | ||||||
|  | 		chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	s := &Server{ | 	s := &Server{ | ||||||
| 		Handler: handler, | 		Handler: chain.Then(oauthproxy), | ||||||
| 		Opts:    opts, | 		Opts:    opts, | ||||||
| 		stop:    make(chan struct{}, 1), | 		stop:    make(chan struct{}, 1), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -81,9 +81,6 @@ type OAuthProxy struct { | ||||||
| 	Validator      func(string) bool | 	Validator      func(string) bool | ||||||
| 
 | 
 | ||||||
| 	RobotsPath        string | 	RobotsPath        string | ||||||
| 	PingPath          string |  | ||||||
| 	PingUserAgent     string |  | ||||||
| 	SilencePings      bool |  | ||||||
| 	SignInPath        string | 	SignInPath        string | ||||||
| 	SignOutPath       string | 	SignOutPath       string | ||||||
| 	OAuthStartPath    string | 	OAuthStartPath    string | ||||||
|  | @ -313,9 +310,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) *OAuthPro | ||||||
| 		Validator:      validator, | 		Validator:      validator, | ||||||
| 
 | 
 | ||||||
| 		RobotsPath:        "/robots.txt", | 		RobotsPath:        "/robots.txt", | ||||||
| 		PingPath:          opts.PingPath, |  | ||||||
| 		PingUserAgent:     opts.PingUserAgent, |  | ||||||
| 		SilencePings:      opts.Logging.SilencePing, |  | ||||||
| 		SignInPath:        fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), | 		SignInPath:        fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), | ||||||
| 		SignOutPath:       fmt.Sprintf("%s/sign_out", opts.ProxyPrefix), | 		SignOutPath:       fmt.Sprintf("%s/sign_out", opts.ProxyPrefix), | ||||||
| 		OAuthStartPath:    fmt.Sprintf("%s/start", opts.ProxyPrefix), | 		OAuthStartPath:    fmt.Sprintf("%s/start", opts.ProxyPrefix), | ||||||
|  | @ -468,17 +462,6 @@ func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { | ||||||
| 	fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | 	fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // PingPage responds 200 OK to requests
 |  | ||||||
| func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { |  | ||||||
| 	if p.SilencePings { |  | ||||||
| 		if rl, ok := rw.(*responseLogger); ok { |  | ||||||
| 			rl.silent = true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	rw.WriteHeader(http.StatusOK) |  | ||||||
| 	fmt.Fprintf(rw, "OK") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ErrorPage writes an error response
 | // ErrorPage writes an error response
 | ||||||
| func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { | func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
|  | @ -684,17 +667,6 @@ func prepareNoCache(w http.ResponseWriter) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // IsPingRequest will check if the request appears to be performing a health check
 |  | ||||||
| // either via the path it's requesting or by a special User-Agent configuration.
 |  | ||||||
| func (p *OAuthProxy) IsPingRequest(req *http.Request) bool { |  | ||||||
| 
 |  | ||||||
| 	if req.URL.EscapedPath() == p.PingPath { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return p.PingUserAgent != "" && req.Header.Get("User-Agent") == p.PingUserAgent |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	if strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { | 	if strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { | ||||||
| 		prepareNoCache(rw) | 		prepareNoCache(rw) | ||||||
|  | @ -703,8 +675,6 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	switch path := req.URL.Path; { | 	switch path := req.URL.Path; { | ||||||
| 	case path == p.RobotsPath: | 	case path == p.RobotsPath: | ||||||
| 		p.RobotsTxt(rw) | 		p.RobotsTxt(rw) | ||||||
| 	case p.IsPingRequest(req): |  | ||||||
| 		p.PingPage(rw) |  | ||||||
| 	case p.IsWhitelistedRequest(req): | 	case p.IsWhitelistedRequest(req): | ||||||
| 		p.serveMux.ServeHTTP(rw, req) | 		p.serveMux.ServeHTTP(rw, req) | ||||||
| 	case path == p.SignInPath: | 	case path == p.SignInPath: | ||||||
|  |  | ||||||
|  | @ -9,7 +9,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // configureLogger is responsible for configuring the logger based on the options given
 | // configureLogger is responsible for configuring the logger based on the options given
 | ||||||
| func configureLogger(o options.Logging, pingPath string, msgs []string) []string { | func configureLogger(o options.Logging, msgs []string) []string { | ||||||
| 	// Setup the log file
 | 	// Setup the log file
 | ||||||
| 	if len(o.File.Filename) > 0 { | 	if len(o.File.Filename) > 0 { | ||||||
| 		// Validate that the file/dir can be written
 | 		// Validate that the file/dir can be written
 | ||||||
|  | @ -48,11 +48,7 @@ func configureLogger(o options.Logging, pingPath string, msgs []string) []string | ||||||
| 	logger.SetAuthTemplate(o.AuthFormat) | 	logger.SetAuthTemplate(o.AuthFormat) | ||||||
| 	logger.SetReqTemplate(o.RequestFormat) | 	logger.SetReqTemplate(o.RequestFormat) | ||||||
| 
 | 
 | ||||||
| 	excludePaths := o.ExcludePaths | 	logger.SetExcludePaths(o.ExcludePaths) | ||||||
| 	if o.SilencePing { |  | ||||||
| 		excludePaths = append(excludePaths, pingPath) |  | ||||||
| 	} |  | ||||||
| 	logger.SetExcludePaths(excludePaths) |  | ||||||
| 
 | 
 | ||||||
| 	if !o.LocalTime { | 	if !o.LocalTime { | ||||||
| 		logger.SetFlags(logger.Flags() | logger.LUTC) | 		logger.SetFlags(logger.Flags() | logger.LUTC) | ||||||
|  |  | ||||||
|  | @ -264,7 +264,7 @@ func Validate(o *options.Options) error { | ||||||
| 
 | 
 | ||||||
| 	msgs = parseSignatureKey(o, msgs) | 	msgs = parseSignatureKey(o, msgs) | ||||||
| 	msgs = validateCookieName(o, msgs) | 	msgs = validateCookieName(o, msgs) | ||||||
| 	msgs = configureLogger(o.Logging, o.PingPath, msgs) | 	msgs = configureLogger(o.Logging, msgs) | ||||||
| 
 | 
 | ||||||
| 	if o.ReverseProxy { | 	if o.ReverseProxy { | ||||||
| 		parser, err := ip.GetRealClientIPParser(o.RealClientIPHeader) | 		parser, err := ip.GetRealClientIPParser(o.RealClientIPHeader) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue