Move upstream information to request scope
This commit is contained in:
		
							parent
							
								
									18cd045631
								
							
						
					
					
						commit
						2e72d151e2
					
				|  | @ -11,16 +11,15 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
 | // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
 | ||||||
| // code and body size
 | // code and body size
 | ||||||
| type responseLogger struct { | type responseLogger struct { | ||||||
| 	w        http.ResponseWriter | 	w      http.ResponseWriter | ||||||
| 	status   int | 	status int | ||||||
| 	size     int | 	size   int | ||||||
| 	upstream string |  | ||||||
| 	authInfo string |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Header returns the ResponseWriter's Header
 | // Header returns the ResponseWriter's Header
 | ||||||
|  | @ -36,19 +35,17 @@ func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err erro | ||||||
| 	return nil, nil, errors.New("http.Hijacker is not available on writer") | 	return nil, nil, errors.New("http.Hijacker is not available on writer") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
 | // extractMetadata extracts metadata from the request/reqsponse for logging
 | ||||||
| // Header
 | func extractMetadata(rw http.ResponseWriter, req *http.Request) (string, string) { | ||||||
| func (l *responseLogger) ExtractGAPMetadata() { | 	scope := middleware.GetRequestScope(req) | ||||||
| 	upstream := l.w.Header().Get("GAP-Upstream-Address") | 	upstream := scope.Upstream | ||||||
| 	if upstream != "" { | 
 | ||||||
| 		l.upstream = upstream | 	authInfo := rw.Header().Get("GAP-Auth") | ||||||
| 		l.w.Header().Del("GAP-Upstream-Address") |  | ||||||
| 	} |  | ||||||
| 	authInfo := l.w.Header().Get("GAP-Auth") |  | ||||||
| 	if authInfo != "" { | 	if authInfo != "" { | ||||||
| 		l.authInfo = authInfo | 		rw.Header().Del("GAP-Auth") | ||||||
| 		l.w.Header().Del("GAP-Auth") |  | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	return authInfo, upstream | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Write writes the response using the ResponseWriter
 | // Write writes the response using the ResponseWriter
 | ||||||
|  | @ -57,7 +54,6 @@ func (l *responseLogger) Write(b []byte) (int, error) { | ||||||
| 		// The status will be StatusOK if WriteHeader has not been called yet
 | 		// The status will be StatusOK if WriteHeader has not been called yet
 | ||||||
| 		l.status = http.StatusOK | 		l.status = http.StatusOK | ||||||
| 	} | 	} | ||||||
| 	l.ExtractGAPMetadata() |  | ||||||
| 	size, err := l.w.Write(b) | 	size, err := l.w.Write(b) | ||||||
| 	l.size += size | 	l.size += size | ||||||
| 	return size, err | 	return size, err | ||||||
|  | @ -65,7 +61,6 @@ func (l *responseLogger) Write(b []byte) (int, error) { | ||||||
| 
 | 
 | ||||||
| // WriteHeader writes the status code for the Response
 | // WriteHeader writes the status code for the Response
 | ||||||
| func (l *responseLogger) WriteHeader(s int) { | func (l *responseLogger) WriteHeader(s int) { | ||||||
| 	l.ExtractGAPMetadata() |  | ||||||
| 	l.w.WriteHeader(s) | 	l.w.WriteHeader(s) | ||||||
| 	l.status = s | 	l.status = s | ||||||
| } | } | ||||||
|  | @ -104,5 +99,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | ||||||
| 	url := *req.URL | 	url := *req.URL | ||||||
| 	responseLogger := &responseLogger{w: w} | 	responseLogger := &responseLogger{w: w} | ||||||
| 	h.handler.ServeHTTP(responseLogger, req) | 	h.handler.ServeHTTP(responseLogger, req) | ||||||
| 	logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) | 
 | ||||||
|  | 	authInfo, upstream := extractMetadata(w, req) | ||||||
|  | 	logger.PrintReq(authInfo, upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -6,7 +6,9 @@ import ( | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -102,7 +104,7 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { | ||||||
| 		logger.SetOutput(buf) | 		logger.SetOutput(buf) | ||||||
| 		logger.SetReqTemplate(test.Format) | 		logger.SetReqTemplate(test.Format) | ||||||
| 		logger.SetExcludePaths(test.ExcludePaths) | 		logger.SetExcludePaths(test.ExcludePaths) | ||||||
| 		h := LoggingHandler(http.HandlerFunc(handler)) | 		h := alice.New(middleware.NewScope(), LoggingHandler).Then(http.HandlerFunc(handler)) | ||||||
| 
 | 
 | ||||||
| 		r, _ := http.NewRequest("GET", test.Path, nil) | 		r, _ := http.NewRequest("GET", test.Path, nil) | ||||||
| 		r.RemoteAddr = "127.0.0.1" | 		r.RemoteAddr = "127.0.0.1" | ||||||
|  |  | ||||||
|  | @ -21,4 +21,7 @@ type RequestScope struct { | ||||||
| 	// SessionRevalidated indicates whether the session has been revalidated since
 | 	// SessionRevalidated indicates whether the session has been revalidated since
 | ||||||
| 	// it was loaded or not.
 | 	// it was loaded or not.
 | ||||||
| 	SessionRevalidated bool | 	SessionRevalidated bool | ||||||
|  | 
 | ||||||
|  | 	// Upstream indicates which (if any) upstream server the request was proxied to.
 | ||||||
|  | 	Upstream string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,6 +4,8 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const fileScheme = "file" | const fileScheme = "file" | ||||||
|  | @ -37,6 +39,11 @@ type fileServer struct { | ||||||
| // ServeHTTP proxies requests to the upstream provider while signing the
 | // ServeHTTP proxies requests to the upstream provider while signing the
 | ||||||
| // request headers
 | // request headers
 | ||||||
| func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	rw.Header().Set("GAP-Upstream-Address", u.upstream) | 	scope := middleware.GetRequestScope(req) | ||||||
|  | 
 | ||||||
|  | 	// If scope is nil, this will panic.
 | ||||||
|  | 	// A scope should always be injected before this handler is called.
 | ||||||
|  | 	scope.Upstream = u.upstream | ||||||
|  | 
 | ||||||
| 	u.handler.ServeHTTP(rw, req) | 	u.handler.ServeHTTP(rw, req) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -7,6 +7,9 @@ import ( | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"os" | 	"os" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -16,6 +19,7 @@ var _ = Describe("File Server Suite", func() { | ||||||
| 	var dir string | 	var dir string | ||||||
| 	var handler http.Handler | 	var handler http.Handler | ||||||
| 	var id string | 	var id string | ||||||
|  | 	var scope *middlewareapi.RequestScope | ||||||
| 
 | 
 | ||||||
| 	const ( | 	const ( | ||||||
| 		foo          = "foo" | 		foo          = "foo" | ||||||
|  | @ -25,14 +29,24 @@ var _ = Describe("File Server Suite", func() { | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	BeforeEach(func() { | 	BeforeEach(func() { | ||||||
| 		// Generate a random id before each test to check the GAP-Upstream-Address
 | 		// Generate a random id before each test to check the upstream
 | ||||||
| 		// is being set correctly
 | 		// is being set correctly in the scope
 | ||||||
| 		idBytes := make([]byte, 16) | 		idBytes := make([]byte, 16) | ||||||
| 		_, err := io.ReadFull(rand.Reader, idBytes) | 		_, err := io.ReadFull(rand.Reader, idBytes) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
| 		id = string(idBytes) | 		id = string(idBytes) | ||||||
| 
 | 
 | ||||||
| 		handler = newFileServer(id, "/files", filesDir) | 		scope = nil | ||||||
|  | 		// Extract the scope so that we can see that the upstream has been set
 | ||||||
|  | 		// correctly
 | ||||||
|  | 		extractScope := func(next http.Handler) http.Handler { | ||||||
|  | 			return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 				scope = middleware.GetRequestScope(req) | ||||||
|  | 				next.ServeHTTP(rw, req) | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		handler = alice.New(middleware.NewScope(), extractScope).Then(newFileServer(id, "/files", filesDir)) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	AfterEach(func() { | 	AfterEach(func() { | ||||||
|  | @ -45,7 +59,7 @@ var _ = Describe("File Server Suite", func() { | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 			handler.ServeHTTP(rw, req) | 			handler.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | 			Expect(scope.Upstream).To(Equal(id)) | ||||||
| 			Expect(rw.Code).To(Equal(expectedResponseCode)) | 			Expect(rw.Code).To(Equal(expectedResponseCode)) | ||||||
| 			Expect(rw.Body.String()).To(Equal(expectedBody)) | 			Expect(rw.Body.String()).To(Equal(expectedBody)) | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
|  | @ -10,6 +10,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| 	"github.com/yhat/wsutil" | 	"github.com/yhat/wsutil" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -77,7 +78,12 @@ type httpUpstreamProxy struct { | ||||||
| // ServeHTTP proxies requests to the upstream provider while signing the
 | // ServeHTTP proxies requests to the upstream provider while signing the
 | ||||||
| // request headers
 | // request headers
 | ||||||
| func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	rw.Header().Set("GAP-Upstream-Address", h.upstream) | 	scope := middleware.GetRequestScope(req) | ||||||
|  | 
 | ||||||
|  | 	// If scope is nil, this will panic.
 | ||||||
|  | 	// A scope should always be injected before this handler is called.
 | ||||||
|  | 	scope.Upstream = h.upstream | ||||||
|  | 
 | ||||||
| 	if h.auth != nil { | 	if h.auth != nil { | ||||||
| 		req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) | 		req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) | ||||||
| 		h.auth.SignRequest(req) | 		h.auth.SignRequest(req) | ||||||
|  |  | ||||||
|  | @ -13,7 +13,10 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -35,6 +38,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 		body             []byte | 		body             []byte | ||||||
| 		signatureData    *options.SignatureData | 		signatureData    *options.SignatureData | ||||||
| 		existingHeaders  map[string]string | 		existingHeaders  map[string]string | ||||||
|  | 		expectedUpstream string | ||||||
| 		expectedResponse testHTTPResponse | 		expectedResponse testHTTPResponse | ||||||
| 		errorHandler     ProxyErrorHandler | 		errorHandler     ProxyErrorHandler | ||||||
| 	} | 	} | ||||||
|  | @ -66,10 +70,21 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			u, err := url.Parse(*in.serverAddr) | 			u, err := url.Parse(*in.serverAddr) | ||||||
| 			Expect(err).ToNot(HaveOccurred()) | 			Expect(err).ToNot(HaveOccurred()) | ||||||
| 
 | 
 | ||||||
| 			handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler) | 			var scope *middlewareapi.RequestScope | ||||||
|  | 			// Extract the scope so that we can see that the upstream has been set
 | ||||||
|  | 			// correctly
 | ||||||
|  | 			extractScope := func(next http.Handler) http.Handler { | ||||||
|  | 				return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 					scope = middleware.GetRequestScope(req) | ||||||
|  | 					next.ServeHTTP(rw, req) | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler)) | ||||||
| 			handler.ServeHTTP(rw, req) | 			handler.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
| 			Expect(rw.Code).To(Equal(in.expectedResponse.code)) | 			Expect(rw.Code).To(Equal(in.expectedResponse.code)) | ||||||
|  | 			Expect(scope.Upstream).To(Equal(in.expectedUpstream)) | ||||||
| 
 | 
 | ||||||
| 			// Delete extra headers that aren't relevant to tests
 | 			// Delete extra headers that aren't relevant to tests
 | ||||||
| 			testSanitizeResponseHeader(rw.Header()) | 			testSanitizeResponseHeader(rw.Header()) | ||||||
|  | @ -88,16 +103,16 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			Expect(request).To(Equal(in.expectedResponse.request)) | 			Expect(request).To(Equal(in.expectedResponse.request)) | ||||||
| 		}, | 		}, | ||||||
| 		Entry("request a path on the server", &httpUpstreamTableInput{ | 		Entry("request a path on the server", &httpUpstreamTableInput{ | ||||||
| 			id:           "default", | 			id:               "default", | ||||||
| 			serverAddr:   &serverAddr, | 			serverAddr:       &serverAddr, | ||||||
| 			target:       "http://example.localhost/foo", | 			target:           "http://example.localhost/foo", | ||||||
| 			method:       "GET", | 			method:           "GET", | ||||||
| 			body:         []byte{}, | 			body:             []byte{}, | ||||||
| 			errorHandler: nil, | 			errorHandler:     nil, | ||||||
|  | 			expectedUpstream: "default", | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"default"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -111,16 +126,16 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("request a path with encoded slashes", &httpUpstreamTableInput{ | 		Entry("request a path with encoded slashes", &httpUpstreamTableInput{ | ||||||
| 			id:           "encodedSlashes", | 			id:               "encodedSlashes", | ||||||
| 			serverAddr:   &serverAddr, | 			serverAddr:       &serverAddr, | ||||||
| 			target:       "http://example.localhost/foo%2fbar/?baz=1", | 			target:           "http://example.localhost/foo%2fbar/?baz=1", | ||||||
| 			method:       "GET", | 			method:           "GET", | ||||||
| 			body:         []byte{}, | 			body:             []byte{}, | ||||||
| 			errorHandler: nil, | 			errorHandler:     nil, | ||||||
|  | 			expectedUpstream: "encodedSlashes", | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"encodedSlashes"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -134,16 +149,16 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("when the request has a body", &httpUpstreamTableInput{ | 		Entry("when the request has a body", &httpUpstreamTableInput{ | ||||||
| 			id:           "requestWithBody", | 			id:               "requestWithBody", | ||||||
| 			serverAddr:   &serverAddr, | 			serverAddr:       &serverAddr, | ||||||
| 			target:       "http://example.localhost/withBody", | 			target:           "http://example.localhost/withBody", | ||||||
| 			method:       "POST", | 			method:           "POST", | ||||||
| 			body:         []byte("body"), | 			body:             []byte("body"), | ||||||
| 			errorHandler: nil, | 			errorHandler:     nil, | ||||||
|  | 			expectedUpstream: "requestWithBody", | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"requestWithBody"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -159,17 +174,16 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("when the upstream is unavailable", &httpUpstreamTableInput{ | 		Entry("when the upstream is unavailable", &httpUpstreamTableInput{ | ||||||
| 			id:           "unavailableUpstream", | 			id:               "unavailableUpstream", | ||||||
| 			serverAddr:   &invalidServer, | 			serverAddr:       &invalidServer, | ||||||
| 			target:       "http://example.localhost/unavailableUpstream", | 			target:           "http://example.localhost/unavailableUpstream", | ||||||
| 			method:       "GET", | 			method:           "GET", | ||||||
| 			body:         []byte{}, | 			body:             []byte{}, | ||||||
| 			errorHandler: nil, | 			errorHandler:     nil, | ||||||
|  | 			expectedUpstream: "unavailableUpstream", | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 502, | 				code:    502, | ||||||
| 				header: map[string][]string{ | 				header:  map[string][]string{}, | ||||||
| 					gapUpstream: {"unavailableUpstream"}, |  | ||||||
| 				}, |  | ||||||
| 				request: testHTTPRequest{}, | 				request: testHTTPRequest{}, | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
|  | @ -183,11 +197,10 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 				rw.WriteHeader(502) | 				rw.WriteHeader(502) | ||||||
| 				rw.Write([]byte("error")) | 				rw.Write([]byte("error")) | ||||||
| 			}, | 			}, | ||||||
|  | 			expectedUpstream: "withErrorHandler", | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 502, | 				code:    502, | ||||||
| 				header: map[string][]string{ | 				header:  map[string][]string{}, | ||||||
| 					gapUpstream: {"withErrorHandler"}, |  | ||||||
| 				}, |  | ||||||
| 				raw:     "error", | 				raw:     "error", | ||||||
| 				request: testHTTPRequest{}, | 				request: testHTTPRequest{}, | ||||||
| 			}, | 			}, | ||||||
|  | @ -202,12 +215,12 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 				Hash: crypto.SHA256, | 				Hash: crypto.SHA256, | ||||||
| 				Key:  "key", | 				Key:  "key", | ||||||
| 			}, | 			}, | ||||||
| 			errorHandler: nil, | 			errorHandler:     nil, | ||||||
|  | 			expectedUpstream: "withSignature", | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 					gapUpstream: {"withSignature"}, |  | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
| 					Method: "GET", | 					Method: "GET", | ||||||
|  | @ -223,12 +236,13 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with existing headers", &httpUpstreamTableInput{ | 		Entry("with existing headers", &httpUpstreamTableInput{ | ||||||
| 			id:           "existingHeaders", | 			id:               "existingHeaders", | ||||||
| 			serverAddr:   &serverAddr, | 			serverAddr:       &serverAddr, | ||||||
| 			target:       "http://example.localhost/existingHeaders", | 			target:           "http://example.localhost/existingHeaders", | ||||||
| 			method:       "GET", | 			method:           "GET", | ||||||
| 			body:         []byte{}, | 			body:             []byte{}, | ||||||
| 			errorHandler: nil, | 			errorHandler:     nil, | ||||||
|  | 			expectedUpstream: "existingHeaders", | ||||||
| 			existingHeaders: map[string]string{ | 			existingHeaders: map[string]string{ | ||||||
| 				"Header1": "value1", | 				"Header1": "value1", | ||||||
| 				"Header2": "value2", | 				"Header2": "value2", | ||||||
|  | @ -236,7 +250,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			expectedResponse: testHTTPResponse{ | 			expectedResponse: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"existingHeaders"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -274,18 +287,21 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 		httpUpstream, ok := handler.(*httpUpstreamProxy) | 		httpUpstream, ok := handler.(*httpUpstreamProxy) | ||||||
| 		Expect(ok).To(BeTrue()) | 		Expect(ok).To(BeTrue()) | ||||||
| 
 | 
 | ||||||
|  | 		var gotRequest *http.Request | ||||||
| 		// Override the handler to just run the director and not actually send the request
 | 		// Override the handler to just run the director and not actually send the request
 | ||||||
| 		requestInterceptor := func(h http.Handler) http.Handler { | 		requestInterceptor := func(h http.Handler) http.Handler { | ||||||
| 			return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { | 			return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { | ||||||
| 				proxy, ok := h.(*httputil.ReverseProxy) | 				proxy, ok := h.(*httputil.ReverseProxy) | ||||||
| 				Expect(ok).To(BeTrue()) | 				Expect(ok).To(BeTrue()) | ||||||
| 				proxy.Director(req) | 				proxy.Director(req) | ||||||
|  | 
 | ||||||
|  | 				gotRequest = req | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 		httpUpstream.handler = requestInterceptor(httpUpstream.handler) | 		httpUpstream.handler = requestInterceptor(httpUpstream.handler) | ||||||
| 
 | 
 | ||||||
| 		httpUpstream.ServeHTTP(rw, req) | 		alice.New(middleware.NewScope()).Then(httpUpstream).ServeHTTP(rw, req) | ||||||
| 		Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) | 		Expect(gotRequest.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	type newUpstreamTableInput struct { | 	type newUpstreamTableInput struct { | ||||||
|  | @ -368,6 +384,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 
 | 
 | ||||||
| 	Context("with a websocket proxy", func() { | 	Context("with a websocket proxy", func() { | ||||||
| 		var proxyServer *httptest.Server | 		var proxyServer *httptest.Server | ||||||
|  | 		var scope *middlewareapi.RequestScope | ||||||
| 
 | 
 | ||||||
| 		BeforeEach(func() { | 		BeforeEach(func() { | ||||||
| 			flush := 1 * time.Second | 			flush := 1 * time.Second | ||||||
|  | @ -382,7 +399,17 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			u, err := url.Parse(serverAddr) | 			u, err := url.Parse(serverAddr) | ||||||
| 			Expect(err).ToNot(HaveOccurred()) | 			Expect(err).ToNot(HaveOccurred()) | ||||||
| 
 | 
 | ||||||
| 			handler := newHTTPUpstreamProxy(upstream, u, nil, nil) | 			scope = nil | ||||||
|  | 			// Extract the scope so that we can see that the upstream has been set
 | ||||||
|  | 			// correctly
 | ||||||
|  | 			extractScope := func(next http.Handler) http.Handler { | ||||||
|  | 				return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 					scope = middleware.GetRequestScope(req) | ||||||
|  | 					next.ServeHTTP(rw, req) | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, nil, nil)) | ||||||
| 			proxyServer = httptest.NewServer(handler) | 			proxyServer = httptest.NewServer(handler) | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
|  | @ -414,7 +441,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | 			response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | ||||||
| 			Expect(err).ToNot(HaveOccurred()) | 			Expect(err).ToNot(HaveOccurred()) | ||||||
| 			Expect(response.StatusCode).To(Equal(200)) | 			Expect(response.StatusCode).To(Equal(200)) | ||||||
| 			Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) | 			Expect(scope.Upstream).To(Equal("websocketProxy")) | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -8,7 +8,10 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -16,6 +19,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| var _ = Describe("Proxy Suite", func() { | var _ = Describe("Proxy Suite", func() { | ||||||
| 	var upstreamServer http.Handler | 	var upstreamServer http.Handler | ||||||
|  | 	var scope *middlewareapi.RequestScope | ||||||
| 
 | 
 | ||||||
| 	BeforeEach(func() { | 	BeforeEach(func() { | ||||||
| 		sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} | 		sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} | ||||||
|  | @ -56,12 +60,25 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) | 		proxyServer, err := NewProxy(upstreams, sigData, errorHandler) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		scope = nil | ||||||
|  | 		// Extract the scope so that we can see that the upstream has been set
 | ||||||
|  | 		// correctly
 | ||||||
|  | 		extractScope := func(next http.Handler) http.Handler { | ||||||
|  | 			return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 				scope = middleware.GetRequestScope(req) | ||||||
|  | 				next.ServeHTTP(rw, req) | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		upstreamServer = alice.New(middleware.NewScope(), extractScope).Then(proxyServer) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	type proxyTableInput struct { | 	type proxyTableInput struct { | ||||||
| 		target   string | 		target   string | ||||||
|  | 		upstream string | ||||||
| 		response testHTTPResponse | 		response testHTTPResponse | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -75,6 +92,7 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 			upstreamServer.ServeHTTP(rw, req) | 			upstreamServer.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
| 			Expect(rw.Code).To(Equal(in.response.code)) | 			Expect(rw.Code).To(Equal(in.response.code)) | ||||||
|  | 			Expect(scope.Upstream).To(Equal(in.upstream)) | ||||||
| 
 | 
 | ||||||
| 			// Delete extra headers that aren't relevant to tests
 | 			// Delete extra headers that aren't relevant to tests
 | ||||||
| 			testSanitizeResponseHeader(rw.Header()) | 			testSanitizeResponseHeader(rw.Header()) | ||||||
|  | @ -94,11 +112,11 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 			Expect(request).To(Equal(in.response.request)) | 			Expect(request).To(Equal(in.response.request)) | ||||||
| 		}, | 		}, | ||||||
| 		Entry("with a request to the HTTP service", &proxyTableInput{ | 		Entry("with a request to the HTTP service", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/http/1234", | 			target:   "http://example.localhost/http/1234", | ||||||
|  | 			upstream: "http-backend", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					gapUpstream: {"http-backend"}, |  | ||||||
| 					contentType: {applicationJSON}, | 					contentType: {applicationJSON}, | ||||||
| 				}, | 				}, | ||||||
| 				request: testHTTPRequest{ | 				request: testHTTPRequest{ | ||||||
|  | @ -115,33 +133,31 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the File backend", &proxyTableInput{ | 		Entry("with a request to the File backend", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/files/foo", | 			target:   "http://example.localhost/files/foo", | ||||||
|  | 			upstream: "file-backend", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code: 200, | 				code: 200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{ | ||||||
| 					contentType: {textPlainUTF8}, | 					contentType: {textPlainUTF8}, | ||||||
| 					gapUpstream: {"file-backend"}, |  | ||||||
| 				}, | 				}, | ||||||
| 				raw: "foo", | 				raw: "foo", | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the Static backend", &proxyTableInput{ | 		Entry("with a request to the Static backend", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/static/bar", | 			target:   "http://example.localhost/static/bar", | ||||||
|  | 			upstream: "static-backend", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code: 200, | 				code:   200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{}, | ||||||
| 					gapUpstream: {"static-backend"}, | 				raw:    "Authenticated", | ||||||
| 				}, |  | ||||||
| 				raw: "Authenticated", |  | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the bad HTTP backend", &proxyTableInput{ | 		Entry("with a request to the bad HTTP backend", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/bad-http/bad", | 			target:   "http://example.localhost/bad-http/bad", | ||||||
|  | 			upstream: "bad-http-backend", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code: 502, | 				code:   502, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{}, | ||||||
| 					gapUpstream: {"bad-http-backend"}, |  | ||||||
| 				}, |  | ||||||
| 				// This tests the error handler
 | 				// This tests the error handler
 | ||||||
| 				raw: "Bad Gateway\nError proxying to upstream server\nprefix", | 				raw: "Bad Gateway\nError proxying to upstream server\nprefix", | ||||||
| 			}, | 			}, | ||||||
|  | @ -158,13 +174,12 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the to backend registered to a single path", &proxyTableInput{ | 		Entry("with a request to the to backend registered to a single path", &proxyTableInput{ | ||||||
| 			target: "http://example.localhost/single-path", | 			target:   "http://example.localhost/single-path", | ||||||
|  | 			upstream: "single-path-backend", | ||||||
| 			response: testHTTPResponse{ | 			response: testHTTPResponse{ | ||||||
| 				code: 200, | 				code:   200, | ||||||
| 				header: map[string][]string{ | 				header: map[string][]string{}, | ||||||
| 					gapUpstream: {"single-path-backend"}, | 				raw:    "Authenticated", | ||||||
| 				}, |  | ||||||
| 				raw: "Authenticated", |  | ||||||
| 			}, | 			}, | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ | 		Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ | ||||||
|  |  | ||||||
|  | @ -3,6 +3,8 @@ package upstream | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const defaultStaticResponseCode = 200 | const defaultStaticResponseCode = 200 | ||||||
|  | @ -24,7 +26,12 @@ type staticResponseHandler struct { | ||||||
| 
 | 
 | ||||||
| // ServeHTTP serves a static response.
 | // ServeHTTP serves a static response.
 | ||||||
| func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	rw.Header().Set("GAP-Upstream-Address", s.upstream) | 	scope := middleware.GetRequestScope(req) | ||||||
|  | 
 | ||||||
|  | 	// If scope is nil, this will panic.
 | ||||||
|  | 	// A scope should always be injected before this handler is called.
 | ||||||
|  | 	scope.Upstream = s.upstream | ||||||
|  | 
 | ||||||
| 	rw.WriteHeader(s.code) | 	rw.WriteHeader(s.code) | ||||||
| 	fmt.Fprintf(rw, "Authenticated") | 	fmt.Fprintf(rw, "Authenticated") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -6,6 +6,9 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -16,8 +19,8 @@ var _ = Describe("Static Response Suite", func() { | ||||||
| 	var id string | 	var id string | ||||||
| 
 | 
 | ||||||
| 	BeforeEach(func() { | 	BeforeEach(func() { | ||||||
| 		// Generate a random id before each test to check the GAP-Upstream-Address
 | 		// Generate a random id before each test to check the upstream
 | ||||||
| 		// is being set correctly
 | 		// is being set correctly in the scope
 | ||||||
| 		idBytes := make([]byte, 16) | 		idBytes := make([]byte, 16) | ||||||
| 		_, err := io.ReadFull(rand.Reader, idBytes) | 		_, err := io.ReadFull(rand.Reader, idBytes) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | @ -37,13 +40,24 @@ var _ = Describe("Static Response Suite", func() { | ||||||
| 			if in.staticCode != 0 { | 			if in.staticCode != 0 { | ||||||
| 				code = &in.staticCode | 				code = &in.staticCode | ||||||
| 			} | 			} | ||||||
| 			handler := newStaticResponseHandler(id, code) | 
 | ||||||
|  | 			var scope *middlewareapi.RequestScope | ||||||
|  | 			// Extract the scope so that we can see that the upstream has been set
 | ||||||
|  | 			// correctly
 | ||||||
|  | 			extractScope := func(next http.Handler) http.Handler { | ||||||
|  | 				return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 					scope = middleware.GetRequestScope(req) | ||||||
|  | 					next.ServeHTTP(rw, req) | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			handler := alice.New(middleware.NewScope(), extractScope).Then(newStaticResponseHandler(id, code)) | ||||||
| 
 | 
 | ||||||
| 			req := httptest.NewRequest("", in.requestPath, nil) | 			req := httptest.NewRequest("", in.requestPath, nil) | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 			handler.ServeHTTP(rw, req) | 			handler.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | 			Expect(scope.Upstream).To(Equal(id)) | ||||||
| 			Expect(rw.Code).To(Equal(in.expectedCode)) | 			Expect(rw.Code).To(Equal(in.expectedCode)) | ||||||
| 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
|  | @ -58,7 +58,6 @@ const ( | ||||||
| 	acceptEncoding  = "Accept-Encoding" | 	acceptEncoding  = "Accept-Encoding" | ||||||
| 	applicationJSON = "application/json" | 	applicationJSON = "application/json" | ||||||
| 	textPlainUTF8   = "text/plain; charset=utf-8" | 	textPlainUTF8   = "text/plain; charset=utf-8" | ||||||
| 	gapUpstream     = "Gap-Upstream-Address" |  | ||||||
| 	gapAuth         = "Gap-Auth" | 	gapAuth         = "Gap-Auth" | ||||||
| 	gapSignature    = "Gap-Signature" | 	gapSignature    = "Gap-Signature" | ||||||
| ) | ) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue