Move Logging to Middleware Package (#1070)
* Use a specialized ResponseWriter in middleware * Track User & Upstream in RequestScope * Wrap responses in our custom ResponseWriter * Add tests for logging middleware * Inject upstream metadata into request scope * Use custom ResponseWriter only in logging middleware * Assume RequestScope is never nil
This commit is contained in:
		
							parent
							
								
									220b3708fc
								
							
						
					
					
						commit
						602dac7852
					
				|  | @ -8,6 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v7.0.1 | ## Changes since v7.0.1 | ||||||
| 
 | 
 | ||||||
|  | - [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) | ||||||
| - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) | - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) | ||||||
| - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) | - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) | ||||||
| - [#1054](https://github.com/oauth2-proxy/oauth2-proxy/pull/1054) Update to Go 1.16 (@JoelSpeed) | - [#1054](https://github.com/oauth2-proxy/oauth2-proxy/pull/1054) Update to Go 1.16 (@JoelSpeed) | ||||||
|  |  | ||||||
|  | @ -1,108 +0,0 @@ | ||||||
| // largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go
 |  | ||||||
| // to add logging of request duration as last value (and drop referrer)
 |  | ||||||
| 
 |  | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"bufio" |  | ||||||
| 	"errors" |  | ||||||
| 	"net" |  | ||||||
| 	"net/http" |  | ||||||
| 	"time" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
 |  | ||||||
| // code and body size
 |  | ||||||
| type responseLogger struct { |  | ||||||
| 	w        http.ResponseWriter |  | ||||||
| 	status   int |  | ||||||
| 	size     int |  | ||||||
| 	upstream string |  | ||||||
| 	authInfo string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Header returns the ResponseWriter's Header
 |  | ||||||
| func (l *responseLogger) Header() http.Header { |  | ||||||
| 	return l.w.Header() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Support Websocket
 |  | ||||||
| func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { |  | ||||||
| 	if hj, ok := l.w.(http.Hijacker); ok { |  | ||||||
| 		return hj.Hijack() |  | ||||||
| 	} |  | ||||||
| 	return nil, nil, errors.New("http.Hijacker is not available on writer") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
 |  | ||||||
| // Header
 |  | ||||||
| func (l *responseLogger) ExtractGAPMetadata() { |  | ||||||
| 	upstream := l.w.Header().Get("GAP-Upstream-Address") |  | ||||||
| 	if upstream != "" { |  | ||||||
| 		l.upstream = upstream |  | ||||||
| 		l.w.Header().Del("GAP-Upstream-Address") |  | ||||||
| 	} |  | ||||||
| 	authInfo := l.w.Header().Get("GAP-Auth") |  | ||||||
| 	if authInfo != "" { |  | ||||||
| 		l.authInfo = authInfo |  | ||||||
| 		l.w.Header().Del("GAP-Auth") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Write writes the response using the ResponseWriter
 |  | ||||||
| func (l *responseLogger) Write(b []byte) (int, error) { |  | ||||||
| 	if l.status == 0 { |  | ||||||
| 		// The status will be StatusOK if WriteHeader has not been called yet
 |  | ||||||
| 		l.status = http.StatusOK |  | ||||||
| 	} |  | ||||||
| 	l.ExtractGAPMetadata() |  | ||||||
| 	size, err := l.w.Write(b) |  | ||||||
| 	l.size += size |  | ||||||
| 	return size, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // WriteHeader writes the status code for the Response
 |  | ||||||
| func (l *responseLogger) WriteHeader(s int) { |  | ||||||
| 	l.ExtractGAPMetadata() |  | ||||||
| 	l.w.WriteHeader(s) |  | ||||||
| 	l.status = s |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Status returns the response status code
 |  | ||||||
| func (l *responseLogger) Status() int { |  | ||||||
| 	return l.status |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Size returns the response size
 |  | ||||||
| func (l *responseLogger) Size() int { |  | ||||||
| 	return l.size |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Flush sends any buffered data to the client
 |  | ||||||
| func (l *responseLogger) Flush() { |  | ||||||
| 	if flusher, ok := l.w.(http.Flusher); ok { |  | ||||||
| 		flusher.Flush() |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // loggingHandler is the http.Handler implementation for LoggingHandler
 |  | ||||||
| type loggingHandler struct { |  | ||||||
| 	handler http.Handler |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // LoggingHandler provides an http.Handler which logs requests to the HTTP server
 |  | ||||||
| func LoggingHandler(h http.Handler) http.Handler { |  | ||||||
| 	return loggingHandler{ |  | ||||||
| 		handler: h, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 	t := time.Now() |  | ||||||
| 	url := *req.URL |  | ||||||
| 	responseLogger := &responseLogger{w: w} |  | ||||||
| 	h.handler.ServeHTTP(responseLogger, req) |  | ||||||
| 	logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) |  | ||||||
| } |  | ||||||
|  | @ -1,116 +0,0 @@ | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"testing" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" |  | ||||||
| 	"github.com/stretchr/testify/assert" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| const RequestLoggingFormatWithoutTime = "{{.Client}} - {{.Username}} [TIMELESS] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" |  | ||||||
| 
 |  | ||||||
| func TestLoggingHandler_ServeHTTP(t *testing.T) { |  | ||||||
| 	tests := []struct { |  | ||||||
| 		Format             string |  | ||||||
| 		ExpectedLogMessage string |  | ||||||
| 		Path               string |  | ||||||
| 		ExcludePaths       []string |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			Format:             RequestLoggingFormatWithoutTime, |  | ||||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", |  | ||||||
| 			Path:               "/foo/bar", |  | ||||||
| 			ExcludePaths:       []string{}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             RequestLoggingFormatWithoutTime, |  | ||||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", |  | ||||||
| 			Path:               "/foo/bar", |  | ||||||
| 			ExcludePaths:       []string{}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             RequestLoggingFormatWithoutTime, |  | ||||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", |  | ||||||
| 			Path:               "/foo/bar", |  | ||||||
| 			ExcludePaths:       []string{"/ping"}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             RequestLoggingFormatWithoutTime, |  | ||||||
| 			ExpectedLogMessage: "", |  | ||||||
| 			Path:               "/foo/bar", |  | ||||||
| 			ExcludePaths:       []string{"/foo/bar"}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             RequestLoggingFormatWithoutTime, |  | ||||||
| 			ExpectedLogMessage: "127.0.0.1 - - [TIMELESS] test-server GET - \"/ping\" HTTP/1.1 \"\" 200 4 0.000\n", |  | ||||||
| 			Path:               "/ping", |  | ||||||
| 			ExcludePaths:       []string{}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             RequestLoggingFormatWithoutTime, |  | ||||||
| 			ExpectedLogMessage: "", |  | ||||||
| 			Path:               "/ping", |  | ||||||
| 			ExcludePaths:       []string{"/ping"}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             RequestLoggingFormatWithoutTime, |  | ||||||
| 			ExpectedLogMessage: "", |  | ||||||
| 			Path:               "/ping", |  | ||||||
| 			ExcludePaths:       []string{"/foo/bar", "/ping"}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             "{{.RequestMethod}}", |  | ||||||
| 			ExpectedLogMessage: "GET\n", |  | ||||||
| 			Path:               "/foo/bar", |  | ||||||
| 			ExcludePaths:       []string{""}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             "{{.RequestMethod}}", |  | ||||||
| 			ExpectedLogMessage: "GET\n", |  | ||||||
| 			Path:               "/foo/bar", |  | ||||||
| 			ExcludePaths:       []string{"/ping"}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             "{{.RequestMethod}}", |  | ||||||
| 			ExpectedLogMessage: "GET\n", |  | ||||||
| 			Path:               "/ping", |  | ||||||
| 			ExcludePaths:       []string{""}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			Format:             "{{.RequestMethod}}", |  | ||||||
| 			ExpectedLogMessage: "", |  | ||||||
| 			Path:               "/ping", |  | ||||||
| 			ExcludePaths:       []string{"/ping"}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, test := range tests { |  | ||||||
| 		buf := bytes.NewBuffer(nil) |  | ||||||
| 		handler := func(w http.ResponseWriter, req *http.Request) { |  | ||||||
| 			_, ok := w.(http.Hijacker) |  | ||||||
| 			if !ok { |  | ||||||
| 				t.Error("http.Hijacker is not available") |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			_, err := w.Write([]byte("test")) |  | ||||||
| 			assert.NoError(t, err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		logger.SetOutput(buf) |  | ||||||
| 		logger.SetReqTemplate(test.Format) |  | ||||||
| 		logger.SetExcludePaths(test.ExcludePaths) |  | ||||||
| 		h := LoggingHandler(http.HandlerFunc(handler)) |  | ||||||
| 
 |  | ||||||
| 		r, _ := http.NewRequest("GET", test.Path, nil) |  | ||||||
| 		r.RemoteAddr = "127.0.0.1" |  | ||||||
| 		r.Host = "test-server" |  | ||||||
| 
 |  | ||||||
| 		h.ServeHTTP(httptest.NewRecorder(), r) |  | ||||||
| 
 |  | ||||||
| 		actual := buf.String() |  | ||||||
| 		assert.Equal(t, test.ExpectedLogMessage, actual) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  | @ -250,9 +250,15 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||||
| 	// To silence logging of health checks, register the health check handler before
 | 	// To silence logging of health checks, register the health check handler before
 | ||||||
| 	// the logging handler
 | 	// the logging handler
 | ||||||
| 	if opts.Logging.SilencePing { | 	if opts.Logging.SilencePing { | ||||||
| 		chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler) | 		chain = chain.Append( | ||||||
|  | 			middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), | ||||||
|  | 			middleware.NewRequestLogger(), | ||||||
|  | 		) | ||||||
| 	} else { | 	} else { | ||||||
| 		chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents)) | 		chain = chain.Append( | ||||||
|  | 			middleware.NewRequestLogger(), | ||||||
|  | 			middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), | ||||||
|  | 		) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) | 	chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) | ||||||
|  |  | ||||||
|  | @ -34,6 +34,9 @@ type RequestScope struct { | ||||||
| 	// SessionRevalidated indicates whether the session has been revalidated since
 | 	// SessionRevalidated indicates whether the session has been revalidated since
 | ||||||
| 	// it was loaded or not.
 | 	// it was loaded or not.
 | ||||||
| 	SessionRevalidated bool | 	SessionRevalidated bool | ||||||
|  | 
 | ||||||
|  | 	// Upstream tracks which upstream was used for this request
 | ||||||
|  | 	Upstream string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetRequestScope returns the current request scope from the given request
 | // GetRequestScope returns the current request scope from the given request
 | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -19,6 +20,17 @@ func TestMiddlewareSuite(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| func testHandler() http.Handler { | func testHandler() http.Handler { | ||||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 		rw.WriteHeader(200) | ||||||
|  | 		rw.Write([]byte("test")) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func testUpstreamHandler(upstream string) http.Handler { | ||||||
|  | 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 		scope := middlewareapi.GetRequestScope(req) | ||||||
|  | 		scope.Upstream = upstream | ||||||
|  | 
 | ||||||
|  | 		rw.WriteHeader(200) | ||||||
| 		rw.Write([]byte("test")) | 		rw.Write([]byte("test")) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,110 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"errors" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // NewRequestLogger returns middleware which logs requests
 | ||||||
|  | // It uses a custom ResponseWriter to track status code & response size details
 | ||||||
|  | func NewRequestLogger() alice.Constructor { | ||||||
|  | 	return requestLogger | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func requestLogger(next http.Handler) http.Handler { | ||||||
|  | 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 		startTime := time.Now() | ||||||
|  | 		url := *req.URL | ||||||
|  | 
 | ||||||
|  | 		responseLogger := &loggingResponse{ResponseWriter: rw} | ||||||
|  | 		next.ServeHTTP(responseLogger, req) | ||||||
|  | 
 | ||||||
|  | 		scope := middlewareapi.GetRequestScope(req) | ||||||
|  | 		// If scope is nil, this will panic.
 | ||||||
|  | 		// A scope should always be injected before this handler is called.
 | ||||||
|  | 		logger.PrintReq( | ||||||
|  | 			getUser(scope), | ||||||
|  | 			scope.Upstream, | ||||||
|  | 			req, | ||||||
|  | 			url, | ||||||
|  | 			startTime, | ||||||
|  | 			responseLogger.Status(), | ||||||
|  | 			responseLogger.Size(), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getUser(scope *middlewareapi.RequestScope) string { | ||||||
|  | 	session := scope.Session | ||||||
|  | 	if session != nil { | ||||||
|  | 		if session.Email != "" { | ||||||
|  | 			return session.Email | ||||||
|  | 		} | ||||||
|  | 		return session.User | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // loggingResponse is a custom http.ResponseWriter that allows tracking certain
 | ||||||
|  | // details for request logging.
 | ||||||
|  | type loggingResponse struct { | ||||||
|  | 	http.ResponseWriter | ||||||
|  | 
 | ||||||
|  | 	status int | ||||||
|  | 	size   int | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Write writes the response using the ResponseWriter
 | ||||||
|  | func (r *loggingResponse) Write(b []byte) (int, error) { | ||||||
|  | 	if r.status == 0 { | ||||||
|  | 		// The status will be StatusOK if WriteHeader has not been called yet
 | ||||||
|  | 		r.status = http.StatusOK | ||||||
|  | 	} | ||||||
|  | 	size, err := r.ResponseWriter.Write(b) | ||||||
|  | 	r.size += size | ||||||
|  | 	return size, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WriteHeader writes the status code for the Response
 | ||||||
|  | func (r *loggingResponse) WriteHeader(s int) { | ||||||
|  | 	r.ResponseWriter.WriteHeader(s) | ||||||
|  | 	r.status = s | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Hijack implements the `http.Hijacker` interface that actual ResponseWriters
 | ||||||
|  | // implement to support websockets
 | ||||||
|  | func (r *loggingResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) { | ||||||
|  | 	if hj, ok := r.ResponseWriter.(http.Hijacker); ok { | ||||||
|  | 		return hj.Hijack() | ||||||
|  | 	} | ||||||
|  | 	return nil, nil, errors.New("http.Hijacker is not available on writer") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Flush sends any buffered data to the client. Implements the `http.Flusher`
 | ||||||
|  | // interface
 | ||||||
|  | func (r *loggingResponse) Flush() { | ||||||
|  | 	if flusher, ok := r.ResponseWriter.(http.Flusher); ok { | ||||||
|  | 		if r.status == 0 { | ||||||
|  | 			// The status will be StatusOK if WriteHeader has not been called yet
 | ||||||
|  | 			r.status = http.StatusOK | ||||||
|  | 		} | ||||||
|  | 		flusher.Flush() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Status returns the response status code
 | ||||||
|  | func (r *loggingResponse) Status() int { | ||||||
|  | 	return r.status | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Size returns the response size
 | ||||||
|  | func (r *loggingResponse) Size() int { | ||||||
|  | 	return r.size | ||||||
|  | } | ||||||
|  | @ -0,0 +1,121 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const RequestLoggingFormatWithoutTime = "{{.Client}} - {{.Username}} [TIMELESS] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Request logger suite", func() { | ||||||
|  | 	type requestLoggerTableInput struct { | ||||||
|  | 		Format             string | ||||||
|  | 		ExpectedLogMessage string | ||||||
|  | 		Path               string | ||||||
|  | 		ExcludePaths       []string | ||||||
|  | 		Upstream           string | ||||||
|  | 		Session            *sessions.SessionState | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DescribeTable("when service a request", | ||||||
|  | 		func(in *requestLoggerTableInput) { | ||||||
|  | 			buf := bytes.NewBuffer(nil) | ||||||
|  | 			logger.SetOutput(buf) | ||||||
|  | 			logger.SetReqTemplate(in.Format) | ||||||
|  | 			logger.SetExcludePaths(in.ExcludePaths) | ||||||
|  | 
 | ||||||
|  | 			req, err := http.NewRequest("GET", in.Path, nil) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			req.RemoteAddr = "127.0.0.1" | ||||||
|  | 			req.Host = "test-server" | ||||||
|  | 
 | ||||||
|  | 			scope := &middlewareapi.RequestScope{Session: in.Session} | ||||||
|  | 			req = middlewareapi.AddRequestScope(req, scope) | ||||||
|  | 
 | ||||||
|  | 			handler := NewRequestLogger()(testUpstreamHandler(in.Upstream)) | ||||||
|  | 			handler.ServeHTTP(httptest.NewRecorder(), req) | ||||||
|  | 
 | ||||||
|  | 			Expect(buf.String()).To(Equal(in.ExpectedLogMessage)) | ||||||
|  | 		}, | ||||||
|  | 		Entry("standard request", &requestLoggerTableInput{ | ||||||
|  | 			Format:             RequestLoggingFormatWithoutTime, | ||||||
|  | 			ExpectedLogMessage: "127.0.0.1 - standard.user [TIMELESS] test-server GET standard \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||||
|  | 			Path:               "/foo/bar", | ||||||
|  | 			ExcludePaths:       []string{}, | ||||||
|  | 			Upstream:           "standard", | ||||||
|  | 			Session:            &sessions.SessionState{User: "standard.user"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with unrelated path excluded", &requestLoggerTableInput{ | ||||||
|  | 			Format:             RequestLoggingFormatWithoutTime, | ||||||
|  | 			ExpectedLogMessage: "127.0.0.1 - unrelated.exclusion [TIMELESS] test-server GET unrelated \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||||
|  | 			Path:               "/foo/bar", | ||||||
|  | 			ExcludePaths:       []string{"/ping"}, | ||||||
|  | 			Upstream:           "unrelated", | ||||||
|  | 			Session:            &sessions.SessionState{User: "unrelated.exclusion"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with path as the sole exclusion", &requestLoggerTableInput{ | ||||||
|  | 			Format:             RequestLoggingFormatWithoutTime, | ||||||
|  | 			ExpectedLogMessage: "", | ||||||
|  | 			Path:               "/foo/bar", | ||||||
|  | 			ExcludePaths:       []string{"/foo/bar"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("ping path", &requestLoggerTableInput{ | ||||||
|  | 			Format:             RequestLoggingFormatWithoutTime, | ||||||
|  | 			ExpectedLogMessage: "127.0.0.1 - mr.ping [TIMELESS] test-server GET - \"/ping\" HTTP/1.1 \"\" 200 4 0.000\n", | ||||||
|  | 			Path:               "/ping", | ||||||
|  | 			ExcludePaths:       []string{}, | ||||||
|  | 			Upstream:           "", | ||||||
|  | 			Session:            &sessions.SessionState{User: "mr.ping"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("ping path but excluded", &requestLoggerTableInput{ | ||||||
|  | 			Format:             RequestLoggingFormatWithoutTime, | ||||||
|  | 			ExpectedLogMessage: "", | ||||||
|  | 			Path:               "/ping", | ||||||
|  | 			ExcludePaths:       []string{"/ping"}, | ||||||
|  | 			Upstream:           "", | ||||||
|  | 			Session:            &sessions.SessionState{User: "mr.ping"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("ping path and excluded in list", &requestLoggerTableInput{ | ||||||
|  | 			Format:             RequestLoggingFormatWithoutTime, | ||||||
|  | 			ExpectedLogMessage: "", | ||||||
|  | 			Path:               "/ping", | ||||||
|  | 			ExcludePaths:       []string{"/foo/bar", "/ping"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("custom format", &requestLoggerTableInput{ | ||||||
|  | 			Format:             "{{.RequestMethod}} {{.Username}} {{.Upstream}}", | ||||||
|  | 			ExpectedLogMessage: "GET custom.format custom\n", | ||||||
|  | 			Path:               "/foo/bar", | ||||||
|  | 			ExcludePaths:       []string{""}, | ||||||
|  | 			Upstream:           "custom", | ||||||
|  | 			Session:            &sessions.SessionState{User: "custom.format"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("custom format with unrelated exclusion", &requestLoggerTableInput{ | ||||||
|  | 			Format:             "{{.RequestMethod}} {{.Username}} {{.Upstream}}", | ||||||
|  | 			ExpectedLogMessage: "GET custom.format custom\n", | ||||||
|  | 			Path:               "/foo/bar", | ||||||
|  | 			ExcludePaths:       []string{"/ping"}, | ||||||
|  | 			Upstream:           "custom", | ||||||
|  | 			Session:            &sessions.SessionState{User: "custom.format"}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("custom format ping path", &requestLoggerTableInput{ | ||||||
|  | 			Format:             "{{.RequestMethod}}", | ||||||
|  | 			ExpectedLogMessage: "GET\n", | ||||||
|  | 			Path:               "/ping", | ||||||
|  | 			ExcludePaths:       []string{""}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("custom format ping path excluded", &requestLoggerTableInput{ | ||||||
|  | 			Format:             "{{.RequestMethod}}", | ||||||
|  | 			ExpectedLogMessage: "", | ||||||
|  | 			Path:               "/ping", | ||||||
|  | 			ExcludePaths:       []string{"/ping"}, | ||||||
|  | 		}), | ||||||
|  | 	) | ||||||
|  | }) | ||||||
|  | @ -4,6 +4,8 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const fileScheme = "file" | const fileScheme = "file" | ||||||
|  | @ -37,6 +39,10 @@ type fileServer struct { | ||||||
| // ServeHTTP proxies requests to the upstream provider while signing the
 | // ServeHTTP proxies requests to the upstream provider while signing the
 | ||||||
| // request headers
 | // request headers
 | ||||||
| func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	rw.Header().Set("GAP-Upstream-Address", u.upstream) | 	scope := middleware.GetRequestScope(req) | ||||||
|  | 	// If scope is nil, this will panic.
 | ||||||
|  | 	// A scope should always be injected before this handler is called.
 | ||||||
|  | 	scope.Upstream = u.upstream | ||||||
|  | 
 | ||||||
| 	u.handler.ServeHTTP(rw, req) | 	u.handler.ServeHTTP(rw, req) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"os" | 	"os" | ||||||
| 
 | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -42,10 +43,14 @@ var _ = Describe("File Server Suite", func() { | ||||||
| 	DescribeTable("fileServer ServeHTTP", | 	DescribeTable("fileServer ServeHTTP", | ||||||
| 		func(requestPath string, expectedResponseCode int, expectedBody string) { | 		func(requestPath string, expectedResponseCode int, expectedBody string) { | ||||||
| 			req := httptest.NewRequest("", requestPath, nil) | 			req := httptest.NewRequest("", requestPath, nil) | ||||||
|  | 			req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||||
|  | 
 | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 			handler.ServeHTTP(rw, req) | 			handler.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | 			scope := middlewareapi.GetRequestScope(req) | ||||||
|  | 			Expect(scope.Upstream).To(Equal(id)) | ||||||
|  | 
 | ||||||
| 			Expect(rw.Code).To(Equal(expectedResponseCode)) | 			Expect(rw.Code).To(Equal(expectedResponseCode)) | ||||||
| 			Expect(rw.Body.String()).To(Equal(expectedBody)) | 			Expect(rw.Body.String()).To(Equal(expectedBody)) | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
|  | @ -8,6 +8,7 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| 	"github.com/yhat/wsutil" | 	"github.com/yhat/wsutil" | ||||||
| ) | ) | ||||||
|  | @ -76,7 +77,12 @@ type httpUpstreamProxy struct { | ||||||
| // ServeHTTP proxies requests to the upstream provider while signing the
 | // ServeHTTP proxies requests to the upstream provider while signing the
 | ||||||
| // request headers
 | // request headers
 | ||||||
| func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	rw.Header().Set("GAP-Upstream-Address", h.upstream) | 	scope := middleware.GetRequestScope(req) | ||||||
|  | 	// If scope is nil, this will panic.
 | ||||||
|  | 	// A scope should always be injected before this handler is called.
 | ||||||
|  | 	scope.Upstream = h.upstream | ||||||
|  | 
 | ||||||
|  | 	// TODO (@NickMeves) - Deprecate GAP-Signature & remove GAP-Auth
 | ||||||
| 	if h.auth != nil { | 	if h.auth != nil { | ||||||
| 		req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) | 		req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) | ||||||
| 		h.auth.SignRequest(req) | 		h.auth.SignRequest(req) | ||||||
|  |  | ||||||
|  | @ -13,7 +13,9 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -36,6 +38,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 		signatureData    *options.SignatureData | 		signatureData    *options.SignatureData | ||||||
| 		existingHeaders  map[string]string | 		existingHeaders  map[string]string | ||||||
| 		expectedResponse testHTTPResponse | 		expectedResponse testHTTPResponse | ||||||
|  | 		expectedUpstream string | ||||||
| 		errorHandler     ProxyErrorHandler | 		errorHandler     ProxyErrorHandler | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -50,6 +53,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 				req.Header.Add(key, value) | 				req.Header.Add(key, value) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | 			req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| 			flush := options.Duration(1 * time.Second) | 			flush := options.Duration(1 * time.Second) | ||||||
|  | @ -71,6 +75,9 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 
 | 
 | ||||||
| 			Expect(rw.Code).To(Equal(in.expectedResponse.code)) | 			Expect(rw.Code).To(Equal(in.expectedResponse.code)) | ||||||
| 
 | 
 | ||||||
|  | 			scope := middlewareapi.GetRequestScope(req) | ||||||
|  | 			Expect(scope.Upstream).To(Equal(in.expectedUpstream)) | ||||||
|  | 
 | ||||||
| 			// Delete extra headers that aren't relevant to tests
 | 			// Delete extra headers that aren't relevant to tests
 | ||||||
| 			testSanitizeResponseHeader(rw.Header()) | 			testSanitizeResponseHeader(rw.Header()) | ||||||
| 			Expect(rw.Header()).To(Equal(in.expectedResponse.header)) | 			Expect(rw.Header()).To(Equal(in.expectedResponse.header)) | ||||||
|  | @ -97,7 +104,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"default"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -109,6 +115,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 					RequestURI: "http://example.localhost/foo", | 					RequestURI: "http://example.localhost/foo", | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "default", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("request a path with encoded slashes", &httpUpstreamTableInput{ | 		Entry("request a path with encoded slashes", &httpUpstreamTableInput{ | ||||||
| 			id:           "encodedSlashes", | 			id:           "encodedSlashes", | ||||||
|  | @ -120,7 +127,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"encodedSlashes"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -132,6 +138,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 					RequestURI: "http://example.localhost/foo%2fbar/?baz=1", | 					RequestURI: "http://example.localhost/foo%2fbar/?baz=1", | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "encodedSlashes", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("when the request has a body", &httpUpstreamTableInput{ | 		Entry("when the request has a body", &httpUpstreamTableInput{ | ||||||
| 			id:           "requestWithBody", | 			id:           "requestWithBody", | ||||||
|  | @ -143,7 +150,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"requestWithBody"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -157,6 +163,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 					RequestURI: "http://example.localhost/withBody", | 					RequestURI: "http://example.localhost/withBody", | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "requestWithBody", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("when the upstream is unavailable", &httpUpstreamTableInput{ | 		Entry("when the upstream is unavailable", &httpUpstreamTableInput{ | ||||||
| 			id:           "unavailableUpstream", | 			id:           "unavailableUpstream", | ||||||
|  | @ -167,11 +174,10 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			errorHandler: nil, | 			errorHandler: nil, | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code:    502, | 				code:    502, | ||||||
| 				header: map[string][]string{ | 				header:  map[string][]string{}, | ||||||
| 					gapUpstream: {"unavailableUpstream"}, |  | ||||||
| 				}, |  | ||||||
| 				request: testHTTPRequest{}, | 				request: testHTTPRequest{}, | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "unavailableUpstream", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{ | 		Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{ | ||||||
| 			id:         "withErrorHandler", | 			id:         "withErrorHandler", | ||||||
|  | @ -185,12 +191,11 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code:    502, | 				code:    502, | ||||||
| 				header: map[string][]string{ | 				header:  map[string][]string{}, | ||||||
| 					gapUpstream: {"withErrorHandler"}, |  | ||||||
| 				}, |  | ||||||
| 				raw:     "error", | 				raw:     "error", | ||||||
| 				request: testHTTPRequest{}, | 				request: testHTTPRequest{}, | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "withErrorHandler", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a signature", &httpUpstreamTableInput{ | 		Entry("with a signature", &httpUpstreamTableInput{ | ||||||
| 			id:         "withSignature", | 			id:         "withSignature", | ||||||
|  | @ -207,7 +212,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 					gapUpstream: {"withSignature"}, |  | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
| 					Method: "GET", | 					Method: "GET", | ||||||
|  | @ -221,6 +225,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 					RequestURI: "http://example.localhost/withSignature", | 					RequestURI: "http://example.localhost/withSignature", | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "withSignature", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with existing headers", &httpUpstreamTableInput{ | 		Entry("with existing headers", &httpUpstreamTableInput{ | ||||||
| 			id:           "existingHeaders", | 			id:           "existingHeaders", | ||||||
|  | @ -236,7 +241,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"existingHeaders"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -251,11 +255,13 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 					RequestURI: "http://example.localhost/existingHeaders", | 					RequestURI: "http://example.localhost/existingHeaders", | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "existingHeaders", | ||||||
| 		}), | 		}), | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	It("ServeHTTP, when not passing a host header", func() { | 	It("ServeHTTP, when not passing a host header", func() { | ||||||
| 		req := httptest.NewRequest("", "http://example.localhost/foo", nil) | 		req := httptest.NewRequest("", "http://example.localhost/foo", nil) | ||||||
|  | 		req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||||
| 		rw := httptest.NewRecorder() | 		rw := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| 		flush := options.Duration(1 * time.Second) | 		flush := options.Duration(1 * time.Second) | ||||||
|  | @ -383,7 +389,8 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			Expect(err).ToNot(HaveOccurred()) | 			Expect(err).ToNot(HaveOccurred()) | ||||||
| 
 | 
 | ||||||
| 			handler := newHTTPUpstreamProxy(upstream, u, nil, nil) | 			handler := newHTTPUpstreamProxy(upstream, u, nil, nil) | ||||||
| 			proxyServer = httptest.NewServer(handler) | 
 | ||||||
|  | 			proxyServer = httptest.NewServer(middleware.NewScope(false)(handler)) | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
| 		AfterEach(func() { | 		AfterEach(func() { | ||||||
|  | @ -414,7 +421,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | 			response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | ||||||
| 			Expect(err).ToNot(HaveOccurred()) | 			Expect(err).ToNot(HaveOccurred()) | ||||||
| 			Expect(response.StatusCode).To(Equal(200)) | 			Expect(response.StatusCode).To(Equal(200)) | ||||||
| 			Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) |  | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 
 | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | @ -64,17 +65,24 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 	type proxyTableInput struct { | 	type proxyTableInput struct { | ||||||
| 		target   string | 		target   string | ||||||
| 		response testHTTPResponse | 		response testHTTPResponse | ||||||
|  | 		upstream string | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	DescribeTable("Proxy ServerHTTP", | 	DescribeTable("Proxy ServeHTTP", | ||||||
| 		func(in *proxyTableInput) { | 		func(in *proxyTableInput) { | ||||||
| 			req := httptest.NewRequest("", in.target, nil) | 			req := middlewareapi.AddRequestScope( | ||||||
|  | 				httptest.NewRequest("", in.target, nil), | ||||||
|  | 				&middlewareapi.RequestScope{}, | ||||||
|  | 			) | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 			// Don't mock the remote Address
 | 			// Don't mock the remote Address
 | ||||||
| 			req.RemoteAddr = "" | 			req.RemoteAddr = "" | ||||||
| 
 | 
 | ||||||
| 			upstreamServer.ServeHTTP(rw, req) | 			upstreamServer.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
|  | 			scope := middlewareapi.GetRequestScope(req) | ||||||
|  | 			Expect(scope.Upstream).To(Equal(in.upstream)) | ||||||
|  | 
 | ||||||
| 			Expect(rw.Code).To(Equal(in.response.code)) | 			Expect(rw.Code).To(Equal(in.response.code)) | ||||||
| 
 | 
 | ||||||
| 			// Delete extra headers that aren't relevant to tests
 | 			// Delete extra headers that aren't relevant to tests
 | ||||||
|  | @ -99,7 +107,6 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"http-backend"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -114,6 +121,7 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 					RequestURI: "http://example.localhost/http/1234", | 					RequestURI: "http://example.localhost/http/1234", | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 			upstream: "http-backend", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the File backend", &proxyTableInput{ | 		Entry("with a request to the File backend", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/files/foo", | 			target: "http://example.localhost/files/foo", | ||||||
|  | @ -121,31 +129,29 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					contentType: {textPlainUTF8}, | 					contentType: {textPlainUTF8}, | ||||||
| 					gapUpstream: {"file-backend"}, |  | ||||||
| 				}, | 				}, | ||||||
| 				raw: "foo", | 				raw: "foo", | ||||||
| 			}, | 			}, | ||||||
|  | 			upstream: "file-backend", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the Static backend", &proxyTableInput{ | 		Entry("with a request to the Static backend", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/static/bar", | 			target: "http://example.localhost/static/bar", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code:   200, | 				code:   200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{}, | ||||||
| 					gapUpstream: {"static-backend"}, |  | ||||||
| 				}, |  | ||||||
| 				raw:    "Authenticated", | 				raw:    "Authenticated", | ||||||
| 			}, | 			}, | ||||||
|  | 			upstream: "static-backend", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the bad HTTP backend", &proxyTableInput{ | 		Entry("with a request to the bad HTTP backend", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/bad-http/bad", | 			target: "http://example.localhost/bad-http/bad", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code:   502, | 				code:   502, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{}, | ||||||
| 					gapUpstream: {"bad-http-backend"}, |  | ||||||
| 				}, |  | ||||||
| 				// This tests the error handler
 | 				// This tests the error handler
 | ||||||
| 				raw: "Proxy Error", | 				raw: "Proxy Error", | ||||||
| 			}, | 			}, | ||||||
|  | 			upstream: "bad-http-backend", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the to an unregistered path", &proxyTableInput{ | 		Entry("with a request to the to an unregistered path", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/unregistered", | 			target: "http://example.localhost/unregistered", | ||||||
|  | @ -162,11 +168,10 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 			target: "http://example.localhost/single-path", | 			target: "http://example.localhost/single-path", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code:   200, | 				code:   200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{}, | ||||||
| 					gapUpstream: {"single-path-backend"}, |  | ||||||
| 				}, |  | ||||||
| 				raw:    "Authenticated", | 				raw:    "Authenticated", | ||||||
| 			}, | 			}, | ||||||
|  | 			upstream: "single-path-backend", | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ | 		Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/single-path/unregistered", | 			target: "http://example.localhost/single-path/unregistered", | ||||||
|  |  | ||||||
|  | @ -3,6 +3,9 @@ package upstream | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const defaultStaticResponseCode = 200 | const defaultStaticResponseCode = 200 | ||||||
|  | @ -24,9 +27,16 @@ type staticResponseHandler struct { | ||||||
| 
 | 
 | ||||||
| // ServeHTTP serves a static response.
 | // ServeHTTP serves a static response.
 | ||||||
| func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	rw.Header().Set("GAP-Upstream-Address", s.upstream) | 	scope := middleware.GetRequestScope(req) | ||||||
|  | 	// If scope is nil, this will panic.
 | ||||||
|  | 	// A scope should always be injected before this handler is called.
 | ||||||
|  | 	scope.Upstream = s.upstream | ||||||
|  | 
 | ||||||
| 	rw.WriteHeader(s.code) | 	rw.WriteHeader(s.code) | ||||||
| 	fmt.Fprintf(rw, "Authenticated") | 	_, err := fmt.Fprintf(rw, "Authenticated") | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Error writing static response: %v", err) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // derefStaticCode returns the derefenced value, or the default if the value is nil
 | // derefStaticCode returns the derefenced value, or the default if the value is nil
 | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 
 | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -40,10 +41,14 @@ var _ = Describe("Static Response Suite", func() { | ||||||
| 			handler := newStaticResponseHandler(id, code) | 			handler := newStaticResponseHandler(id, code) | ||||||
| 
 | 
 | ||||||
| 			req := httptest.NewRequest("", in.requestPath, nil) | 			req := httptest.NewRequest("", in.requestPath, nil) | ||||||
|  | 			req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||||
|  | 
 | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 			handler.ServeHTTP(rw, req) | 			handler.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | 			scope := middlewareapi.GetRequestScope(req) | ||||||
|  | 			Expect(scope.Upstream).To(Equal(id)) | ||||||
|  | 
 | ||||||
| 			Expect(rw.Code).To(Equal(in.expectedCode)) | 			Expect(rw.Code).To(Equal(in.expectedCode)) | ||||||
| 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
|  | @ -59,7 +59,6 @@ const ( | ||||||
| 	acceptEncoding  = "Accept-Encoding" | 	acceptEncoding  = "Accept-Encoding" | ||||||
| 	applicationJSON = "application/json" | 	applicationJSON = "application/json" | ||||||
| 	textPlainUTF8   = "text/plain; charset=utf-8" | 	textPlainUTF8   = "text/plain; charset=utf-8" | ||||||
| 	gapUpstream     = "Gap-Upstream-Address" |  | ||||||
| 	gapAuth         = "Gap-Auth" | 	gapAuth         = "Gap-Auth" | ||||||
| 	gapSignature    = "Gap-Signature" | 	gapSignature    = "Gap-Signature" | ||||||
| ) | ) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue