Move upstream information to request scope
This commit is contained in:
		
							parent
							
								
									18cd045631
								
							
						
					
					
						commit
						2e72d151e2
					
				|  | @ -11,6 +11,7 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| ) | ||||
| 
 | ||||
| // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
 | ||||
|  | @ -19,8 +20,6 @@ type responseLogger struct { | |||
| 	w      http.ResponseWriter | ||||
| 	status int | ||||
| 	size   int | ||||
| 	upstream string | ||||
| 	authInfo string | ||||
| } | ||||
| 
 | ||||
| // Header returns the ResponseWriter's Header
 | ||||
|  | @ -36,19 +35,17 @@ func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err erro | |||
| 	return nil, nil, errors.New("http.Hijacker is not available on writer") | ||||
| } | ||||
| 
 | ||||
| // ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
 | ||||
| // Header
 | ||||
| func (l *responseLogger) ExtractGAPMetadata() { | ||||
| 	upstream := l.w.Header().Get("GAP-Upstream-Address") | ||||
| 	if upstream != "" { | ||||
| 		l.upstream = upstream | ||||
| 		l.w.Header().Del("GAP-Upstream-Address") | ||||
| 	} | ||||
| 	authInfo := l.w.Header().Get("GAP-Auth") | ||||
| // extractMetadata extracts metadata from the request/reqsponse for logging
 | ||||
| func extractMetadata(rw http.ResponseWriter, req *http.Request) (string, string) { | ||||
| 	scope := middleware.GetRequestScope(req) | ||||
| 	upstream := scope.Upstream | ||||
| 
 | ||||
| 	authInfo := rw.Header().Get("GAP-Auth") | ||||
| 	if authInfo != "" { | ||||
| 		l.authInfo = authInfo | ||||
| 		l.w.Header().Del("GAP-Auth") | ||||
| 		rw.Header().Del("GAP-Auth") | ||||
| 	} | ||||
| 
 | ||||
| 	return authInfo, upstream | ||||
| } | ||||
| 
 | ||||
| // Write writes the response using the ResponseWriter
 | ||||
|  | @ -57,7 +54,6 @@ func (l *responseLogger) Write(b []byte) (int, error) { | |||
| 		// The status will be StatusOK if WriteHeader has not been called yet
 | ||||
| 		l.status = http.StatusOK | ||||
| 	} | ||||
| 	l.ExtractGAPMetadata() | ||||
| 	size, err := l.w.Write(b) | ||||
| 	l.size += size | ||||
| 	return size, err | ||||
|  | @ -65,7 +61,6 @@ func (l *responseLogger) Write(b []byte) (int, error) { | |||
| 
 | ||||
| // WriteHeader writes the status code for the Response
 | ||||
| func (l *responseLogger) WriteHeader(s int) { | ||||
| 	l.ExtractGAPMetadata() | ||||
| 	l.w.WriteHeader(s) | ||||
| 	l.status = s | ||||
| } | ||||
|  | @ -104,5 +99,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { | |||
| 	url := *req.URL | ||||
| 	responseLogger := &responseLogger{w: w} | ||||
| 	h.handler.ServeHTTP(responseLogger, req) | ||||
| 	logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) | ||||
| 
 | ||||
| 	authInfo, upstream := extractMetadata(w, req) | ||||
| 	logger.PrintReq(authInfo, upstream, req, url, t, responseLogger.Status(), responseLogger.Size()) | ||||
| } | ||||
|  |  | |||
|  | @ -6,7 +6,9 @@ import ( | |||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -102,7 +104,7 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { | |||
| 		logger.SetOutput(buf) | ||||
| 		logger.SetReqTemplate(test.Format) | ||||
| 		logger.SetExcludePaths(test.ExcludePaths) | ||||
| 		h := LoggingHandler(http.HandlerFunc(handler)) | ||||
| 		h := alice.New(middleware.NewScope(), LoggingHandler).Then(http.HandlerFunc(handler)) | ||||
| 
 | ||||
| 		r, _ := http.NewRequest("GET", test.Path, nil) | ||||
| 		r.RemoteAddr = "127.0.0.1" | ||||
|  |  | |||
|  | @ -21,4 +21,7 @@ type RequestScope struct { | |||
| 	// SessionRevalidated indicates whether the session has been revalidated since
 | ||||
| 	// it was loaded or not.
 | ||||
| 	SessionRevalidated bool | ||||
| 
 | ||||
| 	// Upstream indicates which (if any) upstream server the request was proxied to.
 | ||||
| 	Upstream string | ||||
| } | ||||
|  |  | |||
|  | @ -4,6 +4,8 @@ import ( | |||
| 	"net/http" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| ) | ||||
| 
 | ||||
| const fileScheme = "file" | ||||
|  | @ -37,6 +39,11 @@ type fileServer struct { | |||
| // ServeHTTP proxies requests to the upstream provider while signing the
 | ||||
| // request headers
 | ||||
| func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	rw.Header().Set("GAP-Upstream-Address", u.upstream) | ||||
| 	scope := middleware.GetRequestScope(req) | ||||
| 
 | ||||
| 	// If scope is nil, this will panic.
 | ||||
| 	// A scope should always be injected before this handler is called.
 | ||||
| 	scope.Upstream = u.upstream | ||||
| 
 | ||||
| 	u.handler.ServeHTTP(rw, req) | ||||
| } | ||||
|  |  | |||
|  | @ -7,6 +7,9 @@ import ( | |||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -16,6 +19,7 @@ var _ = Describe("File Server Suite", func() { | |||
| 	var dir string | ||||
| 	var handler http.Handler | ||||
| 	var id string | ||||
| 	var scope *middlewareapi.RequestScope | ||||
| 
 | ||||
| 	const ( | ||||
| 		foo          = "foo" | ||||
|  | @ -25,14 +29,24 @@ var _ = Describe("File Server Suite", func() { | |||
| 	) | ||||
| 
 | ||||
| 	BeforeEach(func() { | ||||
| 		// Generate a random id before each test to check the GAP-Upstream-Address
 | ||||
| 		// is being set correctly
 | ||||
| 		// Generate a random id before each test to check the upstream
 | ||||
| 		// is being set correctly in the scope
 | ||||
| 		idBytes := make([]byte, 16) | ||||
| 		_, err := io.ReadFull(rand.Reader, idBytes) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 		id = string(idBytes) | ||||
| 
 | ||||
| 		handler = newFileServer(id, "/files", filesDir) | ||||
| 		scope = nil | ||||
| 		// Extract the scope so that we can see that the upstream has been set
 | ||||
| 		// correctly
 | ||||
| 		extractScope := func(next http.Handler) http.Handler { | ||||
| 			return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 				scope = middleware.GetRequestScope(req) | ||||
| 				next.ServeHTTP(rw, req) | ||||
| 			}) | ||||
| 		} | ||||
| 
 | ||||
| 		handler = alice.New(middleware.NewScope(), extractScope).Then(newFileServer(id, "/files", filesDir)) | ||||
| 	}) | ||||
| 
 | ||||
| 	AfterEach(func() { | ||||
|  | @ -45,7 +59,7 @@ var _ = Describe("File Server Suite", func() { | |||
| 			rw := httptest.NewRecorder() | ||||
| 			handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | ||||
| 			Expect(scope.Upstream).To(Equal(id)) | ||||
| 			Expect(rw.Code).To(Equal(expectedResponseCode)) | ||||
| 			Expect(rw.Body.String()).To(Equal(expectedBody)) | ||||
| 		}, | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ import ( | |||
| 
 | ||||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	"github.com/yhat/wsutil" | ||||
| ) | ||||
| 
 | ||||
|  | @ -77,7 +78,12 @@ type httpUpstreamProxy struct { | |||
| // ServeHTTP proxies requests to the upstream provider while signing the
 | ||||
| // request headers
 | ||||
| func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	rw.Header().Set("GAP-Upstream-Address", h.upstream) | ||||
| 	scope := middleware.GetRequestScope(req) | ||||
| 
 | ||||
| 	// If scope is nil, this will panic.
 | ||||
| 	// A scope should always be injected before this handler is called.
 | ||||
| 	scope.Upstream = h.upstream | ||||
| 
 | ||||
| 	if h.auth != nil { | ||||
| 		req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) | ||||
| 		h.auth.SignRequest(req) | ||||
|  |  | |||
|  | @ -13,7 +13,10 @@ import ( | |||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -35,6 +38,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 		body             []byte | ||||
| 		signatureData    *options.SignatureData | ||||
| 		existingHeaders  map[string]string | ||||
| 		expectedUpstream string | ||||
| 		expectedResponse testHTTPResponse | ||||
| 		errorHandler     ProxyErrorHandler | ||||
| 	} | ||||
|  | @ -66,10 +70,21 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			u, err := url.Parse(*in.serverAddr) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler) | ||||
| 			var scope *middlewareapi.RequestScope | ||||
| 			// Extract the scope so that we can see that the upstream has been set
 | ||||
| 			// correctly
 | ||||
| 			extractScope := func(next http.Handler) http.Handler { | ||||
| 				return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 					scope = middleware.GetRequestScope(req) | ||||
| 					next.ServeHTTP(rw, req) | ||||
| 				}) | ||||
| 			} | ||||
| 
 | ||||
| 			handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler)) | ||||
| 			handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			Expect(rw.Code).To(Equal(in.expectedResponse.code)) | ||||
| 			Expect(scope.Upstream).To(Equal(in.expectedUpstream)) | ||||
| 
 | ||||
| 			// Delete extra headers that aren't relevant to tests
 | ||||
| 			testSanitizeResponseHeader(rw.Header()) | ||||
|  | @ -94,10 +109,10 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			method:           "GET", | ||||
| 			body:             []byte{}, | ||||
| 			errorHandler:     nil, | ||||
| 			expectedUpstream: "default", | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"default"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
|  | @ -117,10 +132,10 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			method:           "GET", | ||||
| 			body:             []byte{}, | ||||
| 			errorHandler:     nil, | ||||
| 			expectedUpstream: "encodedSlashes", | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"encodedSlashes"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
|  | @ -140,10 +155,10 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			method:           "POST", | ||||
| 			body:             []byte("body"), | ||||
| 			errorHandler:     nil, | ||||
| 			expectedUpstream: "requestWithBody", | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"requestWithBody"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
|  | @ -165,11 +180,10 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			method:           "GET", | ||||
| 			body:             []byte{}, | ||||
| 			errorHandler:     nil, | ||||
| 			expectedUpstream: "unavailableUpstream", | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code:    502, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"unavailableUpstream"}, | ||||
| 				}, | ||||
| 				header:  map[string][]string{}, | ||||
| 				request: testHTTPRequest{}, | ||||
| 			}, | ||||
| 		}), | ||||
|  | @ -183,11 +197,10 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 				rw.WriteHeader(502) | ||||
| 				rw.Write([]byte("error")) | ||||
| 			}, | ||||
| 			expectedUpstream: "withErrorHandler", | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code:    502, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"withErrorHandler"}, | ||||
| 				}, | ||||
| 				header:  map[string][]string{}, | ||||
| 				raw:     "error", | ||||
| 				request: testHTTPRequest{}, | ||||
| 			}, | ||||
|  | @ -203,11 +216,11 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 				Key:  "key", | ||||
| 			}, | ||||
| 			errorHandler:     nil, | ||||
| 			expectedUpstream: "withSignature", | ||||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					contentType: {applicationJSON}, | ||||
| 					gapUpstream: {"withSignature"}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
| 					Method: "GET", | ||||
|  | @ -229,6 +242,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			method:           "GET", | ||||
| 			body:             []byte{}, | ||||
| 			errorHandler:     nil, | ||||
| 			expectedUpstream: "existingHeaders", | ||||
| 			existingHeaders: map[string]string{ | ||||
| 				"Header1": "value1", | ||||
| 				"Header2": "value2", | ||||
|  | @ -236,7 +250,6 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			expectedResponse: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"existingHeaders"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
|  | @ -274,18 +287,21 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 		httpUpstream, ok := handler.(*httpUpstreamProxy) | ||||
| 		Expect(ok).To(BeTrue()) | ||||
| 
 | ||||
| 		var gotRequest *http.Request | ||||
| 		// Override the handler to just run the director and not actually send the request
 | ||||
| 		requestInterceptor := func(h http.Handler) http.Handler { | ||||
| 			return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { | ||||
| 				proxy, ok := h.(*httputil.ReverseProxy) | ||||
| 				Expect(ok).To(BeTrue()) | ||||
| 				proxy.Director(req) | ||||
| 
 | ||||
| 				gotRequest = req | ||||
| 			}) | ||||
| 		} | ||||
| 		httpUpstream.handler = requestInterceptor(httpUpstream.handler) | ||||
| 
 | ||||
| 		httpUpstream.ServeHTTP(rw, req) | ||||
| 		Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) | ||||
| 		alice.New(middleware.NewScope()).Then(httpUpstream).ServeHTTP(rw, req) | ||||
| 		Expect(gotRequest.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) | ||||
| 	}) | ||||
| 
 | ||||
| 	type newUpstreamTableInput struct { | ||||
|  | @ -368,6 +384,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 
 | ||||
| 	Context("with a websocket proxy", func() { | ||||
| 		var proxyServer *httptest.Server | ||||
| 		var scope *middlewareapi.RequestScope | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			flush := 1 * time.Second | ||||
|  | @ -382,7 +399,17 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			u, err := url.Parse(serverAddr) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			handler := newHTTPUpstreamProxy(upstream, u, nil, nil) | ||||
| 			scope = nil | ||||
| 			// Extract the scope so that we can see that the upstream has been set
 | ||||
| 			// correctly
 | ||||
| 			extractScope := func(next http.Handler) http.Handler { | ||||
| 				return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 					scope = middleware.GetRequestScope(req) | ||||
| 					next.ServeHTTP(rw, req) | ||||
| 				}) | ||||
| 			} | ||||
| 
 | ||||
| 			handler := alice.New(middleware.NewScope(), extractScope).Then(newHTTPUpstreamProxy(upstream, u, nil, nil)) | ||||
| 			proxyServer = httptest.NewServer(handler) | ||||
| 		}) | ||||
| 
 | ||||
|  | @ -414,7 +441,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 			response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(response.StatusCode).To(Equal(200)) | ||||
| 			Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) | ||||
| 			Expect(scope.Upstream).To(Equal("websocketProxy")) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|  |  | |||
|  | @ -8,7 +8,10 @@ import ( | |||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -16,6 +19,7 @@ import ( | |||
| 
 | ||||
| var _ = Describe("Proxy Suite", func() { | ||||
| 	var upstreamServer http.Handler | ||||
| 	var scope *middlewareapi.RequestScope | ||||
| 
 | ||||
| 	BeforeEach(func() { | ||||
| 		sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} | ||||
|  | @ -56,12 +60,25 @@ var _ = Describe("Proxy Suite", func() { | |||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) | ||||
| 		proxyServer, err := NewProxy(upstreams, sigData, errorHandler) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 		scope = nil | ||||
| 		// Extract the scope so that we can see that the upstream has been set
 | ||||
| 		// correctly
 | ||||
| 		extractScope := func(next http.Handler) http.Handler { | ||||
| 			return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 				scope = middleware.GetRequestScope(req) | ||||
| 				next.ServeHTTP(rw, req) | ||||
| 			}) | ||||
| 		} | ||||
| 
 | ||||
| 		upstreamServer = alice.New(middleware.NewScope(), extractScope).Then(proxyServer) | ||||
| 	}) | ||||
| 
 | ||||
| 	type proxyTableInput struct { | ||||
| 		target   string | ||||
| 		upstream string | ||||
| 		response testHTTPResponse | ||||
| 	} | ||||
| 
 | ||||
|  | @ -75,6 +92,7 @@ var _ = Describe("Proxy Suite", func() { | |||
| 			upstreamServer.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			Expect(rw.Code).To(Equal(in.response.code)) | ||||
| 			Expect(scope.Upstream).To(Equal(in.upstream)) | ||||
| 
 | ||||
| 			// Delete extra headers that aren't relevant to tests
 | ||||
| 			testSanitizeResponseHeader(rw.Header()) | ||||
|  | @ -95,10 +113,10 @@ var _ = Describe("Proxy Suite", func() { | |||
| 		}, | ||||
| 		Entry("with a request to the HTTP service", &proxyTableInput{ | ||||
| 			target:   "http://example.localhost/http/1234", | ||||
| 			upstream: "http-backend", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"http-backend"}, | ||||
| 					contentType: {applicationJSON}, | ||||
| 				}, | ||||
| 				request: testHTTPRequest{ | ||||
|  | @ -116,32 +134,30 @@ var _ = Describe("Proxy Suite", func() { | |||
| 		}), | ||||
| 		Entry("with a request to the File backend", &proxyTableInput{ | ||||
| 			target:   "http://example.localhost/files/foo", | ||||
| 			upstream: "file-backend", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code: 200, | ||||
| 				header: map[string][]string{ | ||||
| 					contentType: {textPlainUTF8}, | ||||
| 					gapUpstream: {"file-backend"}, | ||||
| 				}, | ||||
| 				raw: "foo", | ||||
| 			}, | ||||
| 		}), | ||||
| 		Entry("with a request to the Static backend", &proxyTableInput{ | ||||
| 			target:   "http://example.localhost/static/bar", | ||||
| 			upstream: "static-backend", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code:   200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"static-backend"}, | ||||
| 				}, | ||||
| 				header: map[string][]string{}, | ||||
| 				raw:    "Authenticated", | ||||
| 			}, | ||||
| 		}), | ||||
| 		Entry("with a request to the bad HTTP backend", &proxyTableInput{ | ||||
| 			target:   "http://example.localhost/bad-http/bad", | ||||
| 			upstream: "bad-http-backend", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code:   502, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"bad-http-backend"}, | ||||
| 				}, | ||||
| 				header: map[string][]string{}, | ||||
| 				// This tests the error handler
 | ||||
| 				raw: "Bad Gateway\nError proxying to upstream server\nprefix", | ||||
| 			}, | ||||
|  | @ -159,11 +175,10 @@ var _ = Describe("Proxy Suite", func() { | |||
| 		}), | ||||
| 		Entry("with a request to the to backend registered to a single path", &proxyTableInput{ | ||||
| 			target:   "http://example.localhost/single-path", | ||||
| 			upstream: "single-path-backend", | ||||
| 			response: testHTTPResponse{ | ||||
| 				code:   200, | ||||
| 				header: map[string][]string{ | ||||
| 					gapUpstream: {"single-path-backend"}, | ||||
| 				}, | ||||
| 				header: map[string][]string{}, | ||||
| 				raw:    "Authenticated", | ||||
| 			}, | ||||
| 		}), | ||||
|  |  | |||
|  | @ -3,6 +3,8 @@ package upstream | |||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| ) | ||||
| 
 | ||||
| const defaultStaticResponseCode = 200 | ||||
|  | @ -24,7 +26,12 @@ type staticResponseHandler struct { | |||
| 
 | ||||
| // ServeHTTP serves a static response.
 | ||||
| func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	rw.Header().Set("GAP-Upstream-Address", s.upstream) | ||||
| 	scope := middleware.GetRequestScope(req) | ||||
| 
 | ||||
| 	// If scope is nil, this will panic.
 | ||||
| 	// A scope should always be injected before this handler is called.
 | ||||
| 	scope.Upstream = s.upstream | ||||
| 
 | ||||
| 	rw.WriteHeader(s.code) | ||||
| 	fmt.Fprintf(rw, "Authenticated") | ||||
| } | ||||
|  |  | |||
|  | @ -6,6 +6,9 @@ import ( | |||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -16,8 +19,8 @@ var _ = Describe("Static Response Suite", func() { | |||
| 	var id string | ||||
| 
 | ||||
| 	BeforeEach(func() { | ||||
| 		// Generate a random id before each test to check the GAP-Upstream-Address
 | ||||
| 		// is being set correctly
 | ||||
| 		// Generate a random id before each test to check the upstream
 | ||||
| 		// is being set correctly in the scope
 | ||||
| 		idBytes := make([]byte, 16) | ||||
| 		_, err := io.ReadFull(rand.Reader, idBytes) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
|  | @ -37,13 +40,24 @@ var _ = Describe("Static Response Suite", func() { | |||
| 			if in.staticCode != 0 { | ||||
| 				code = &in.staticCode | ||||
| 			} | ||||
| 			handler := newStaticResponseHandler(id, code) | ||||
| 
 | ||||
| 			var scope *middlewareapi.RequestScope | ||||
| 			// Extract the scope so that we can see that the upstream has been set
 | ||||
| 			// correctly
 | ||||
| 			extractScope := func(next http.Handler) http.Handler { | ||||
| 				return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 					scope = middleware.GetRequestScope(req) | ||||
| 					next.ServeHTTP(rw, req) | ||||
| 				}) | ||||
| 			} | ||||
| 
 | ||||
| 			handler := alice.New(middleware.NewScope(), extractScope).Then(newStaticResponseHandler(id, code)) | ||||
| 
 | ||||
| 			req := httptest.NewRequest("", in.requestPath, nil) | ||||
| 			rw := httptest.NewRecorder() | ||||
| 			handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | ||||
| 			Expect(scope.Upstream).To(Equal(id)) | ||||
| 			Expect(rw.Code).To(Equal(in.expectedCode)) | ||||
| 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | ||||
| 		}, | ||||
|  |  | |||
|  | @ -58,7 +58,6 @@ const ( | |||
| 	acceptEncoding  = "Accept-Encoding" | ||||
| 	applicationJSON = "application/json" | ||||
| 	textPlainUTF8   = "text/plain; charset=utf-8" | ||||
| 	gapUpstream     = "Gap-Upstream-Address" | ||||
| 	gapAuth         = "Gap-Auth" | ||||
| 	gapSignature    = "Gap-Signature" | ||||
| ) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue