Add tests for upstream package
This commit is contained in:
		
							parent
							
								
									fa8e1ee033
								
							
						
					
					
						commit
						5b95ed3033
					
				|  | @ -0,0 +1,58 @@ | ||||||
|  | package upstream | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"os" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("File Server Suite", func() { | ||||||
|  | 	var dir string | ||||||
|  | 	var handler http.Handler | ||||||
|  | 	var id string | ||||||
|  | 
 | ||||||
|  | 	const ( | ||||||
|  | 		foo          = "foo" | ||||||
|  | 		bar          = "bar" | ||||||
|  | 		baz          = "baz" | ||||||
|  | 		pageNotFound = "404 page not found\n" | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		// Generate a random id before each test to check the GAP-Upstream-Address
 | ||||||
|  | 		// is being set correctly
 | ||||||
|  | 		idBytes := make([]byte, 16) | ||||||
|  | 		_, err := io.ReadFull(rand.Reader, idBytes) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 		id = string(idBytes) | ||||||
|  | 
 | ||||||
|  | 		handler = newFileServer(id, "/files", filesDir) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	AfterEach(func() { | ||||||
|  | 		Expect(os.RemoveAll(dir)).To(Succeed()) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	DescribeTable("fileServer ServeHTTP", | ||||||
|  | 		func(requestPath string, expectedResponseCode int, expectedBody string) { | ||||||
|  | 			req := httptest.NewRequest("", requestPath, nil) | ||||||
|  | 			rw := httptest.NewRecorder() | ||||||
|  | 			handler.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | ||||||
|  | 			Expect(rw.Code).To(Equal(expectedResponseCode)) | ||||||
|  | 			Expect(rw.Body.String()).To(Equal(expectedBody)) | ||||||
|  | 		}, | ||||||
|  | 		Entry("for file foo", "/files/foo", 200, foo), | ||||||
|  | 		Entry("for file bar", "/files/bar", 200, bar), | ||||||
|  | 		Entry("for file foo/baz", "/files/subdir/baz", 200, baz), | ||||||
|  | 		Entry("for a non-existent file inside the path", "/files/baz", 404, pageNotFound), | ||||||
|  | 		Entry("for a non-existent file oustide the path", "/baz", 404, pageNotFound), | ||||||
|  | 	) | ||||||
|  | }) | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"net/http/httputil" | 	"net/http/httputil" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
|  | @ -96,7 +97,12 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr | ||||||
| 	proxy := httputil.NewSingleHostReverseProxy(target) | 	proxy := httputil.NewSingleHostReverseProxy(target) | ||||||
| 
 | 
 | ||||||
| 	// Configure options on the SingleHostReverseProxy
 | 	// Configure options on the SingleHostReverseProxy
 | ||||||
|  | 	if upstream.FlushInterval != nil { | ||||||
| 		proxy.FlushInterval = *upstream.FlushInterval | 		proxy.FlushInterval = *upstream.FlushInterval | ||||||
|  | 	} else { | ||||||
|  | 		proxy.FlushInterval = 1 * time.Second | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if upstream.InsecureSkipTLSVerify { | 	if upstream.InsecureSkipTLSVerify { | ||||||
| 		proxy.Transport = &http.Transport{ | 		proxy.Transport = &http.Transport{ | ||||||
| 			TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, | 			TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, | ||||||
|  |  | ||||||
|  | @ -0,0 +1,417 @@ | ||||||
|  | package upstream | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"net/http/httputil" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | 	"golang.org/x/net/websocket" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("HTTP Upstream Suite", func() { | ||||||
|  | 
 | ||||||
|  | 	const flushInterval5s = 5 * time.Second | ||||||
|  | 	const flushInterval1s = 1 * time.Second | ||||||
|  | 
 | ||||||
|  | 	type httpUpstreamTableInput struct { | ||||||
|  | 		id               string | ||||||
|  | 		serverAddr       *string | ||||||
|  | 		target           string | ||||||
|  | 		method           string | ||||||
|  | 		body             []byte | ||||||
|  | 		signatureData    *options.SignatureData | ||||||
|  | 		existingHeaders  map[string]string | ||||||
|  | 		expectedResponse testHTTPResponse | ||||||
|  | 		errorHandler     ProxyErrorHandler | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DescribeTable("HTTP Upstream ServeHTTP", | ||||||
|  | 		func(in *httpUpstreamTableInput) { | ||||||
|  | 			buf := bytes.NewBuffer(in.body) | ||||||
|  | 			req := httptest.NewRequest(in.method, in.target, buf) | ||||||
|  | 			// Don't mock the remote Address
 | ||||||
|  | 			req.RemoteAddr = "" | ||||||
|  | 
 | ||||||
|  | 			for key, value := range in.existingHeaders { | ||||||
|  | 				req.Header.Add(key, value) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			rw := httptest.NewRecorder() | ||||||
|  | 
 | ||||||
|  | 			flush := 1 * time.Second | ||||||
|  | 			upstream := options.Upstream{ | ||||||
|  | 				ID:                    in.id, | ||||||
|  | 				PassHostHeader:        true, | ||||||
|  | 				ProxyWebSockets:       false, | ||||||
|  | 				InsecureSkipTLSVerify: false, | ||||||
|  | 				FlushInterval:         &flush, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			Expect(in.serverAddr).ToNot(BeNil()) | ||||||
|  | 			u, err := url.Parse(*in.serverAddr) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler) | ||||||
|  | 			handler.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 			Expect(rw.Code).To(Equal(in.expectedResponse.code)) | ||||||
|  | 
 | ||||||
|  | 			// Delete extra headers that aren't relevant to tests
 | ||||||
|  | 			testSanitizeResponseHeader(rw.Header()) | ||||||
|  | 			Expect(rw.Header()).To(Equal(in.expectedResponse.header)) | ||||||
|  | 
 | ||||||
|  | 			body := rw.Body.Bytes() | ||||||
|  | 			if in.expectedResponse.raw != "" || rw.Code != http.StatusOK { | ||||||
|  | 				Expect(string(body)).To(Equal(in.expectedResponse.raw)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Compare the reflected request to the upstream
 | ||||||
|  | 			request := testHTTPRequest{} | ||||||
|  | 			Expect(json.Unmarshal(body, &request)).To(Succeed()) | ||||||
|  | 			testSanitizeRequestHeader(request.Header) | ||||||
|  | 			Expect(request).To(Equal(in.expectedResponse.request)) | ||||||
|  | 		}, | ||||||
|  | 		Entry("request a path on the server", &httpUpstreamTableInput{ | ||||||
|  | 			id:           "default", | ||||||
|  | 			serverAddr:   &serverAddr, | ||||||
|  | 			target:       "http://example.localhost/foo", | ||||||
|  | 			method:       "GET", | ||||||
|  | 			body:         []byte{}, | ||||||
|  | 			errorHandler: nil, | ||||||
|  | 			expectedResponse: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"default"}, | ||||||
|  | 					contentType: {applicationJSON}, | ||||||
|  | 				}, | ||||||
|  | 				request: testHTTPRequest{ | ||||||
|  | 					Method:     "GET", | ||||||
|  | 					URL:        "http://example.localhost/foo", | ||||||
|  | 					Header:     map[string][]string{}, | ||||||
|  | 					Body:       []byte{}, | ||||||
|  | 					Host:       "example.localhost", | ||||||
|  | 					RequestURI: "http://example.localhost/foo", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("request a path with encoded slashes", &httpUpstreamTableInput{ | ||||||
|  | 			id:           "encodedSlashes", | ||||||
|  | 			serverAddr:   &serverAddr, | ||||||
|  | 			target:       "http://example.localhost/foo%2fbar/?baz=1", | ||||||
|  | 			method:       "GET", | ||||||
|  | 			body:         []byte{}, | ||||||
|  | 			errorHandler: nil, | ||||||
|  | 			expectedResponse: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"encodedSlashes"}, | ||||||
|  | 					contentType: {applicationJSON}, | ||||||
|  | 				}, | ||||||
|  | 				request: testHTTPRequest{ | ||||||
|  | 					Method:     "GET", | ||||||
|  | 					URL:        "http://example.localhost/foo%2fbar/?baz=1", | ||||||
|  | 					Header:     map[string][]string{}, | ||||||
|  | 					Body:       []byte{}, | ||||||
|  | 					Host:       "example.localhost", | ||||||
|  | 					RequestURI: "http://example.localhost/foo%2fbar/?baz=1", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("when the request has a body", &httpUpstreamTableInput{ | ||||||
|  | 			id:           "requestWithBody", | ||||||
|  | 			serverAddr:   &serverAddr, | ||||||
|  | 			target:       "http://example.localhost/withBody", | ||||||
|  | 			method:       "POST", | ||||||
|  | 			body:         []byte("body"), | ||||||
|  | 			errorHandler: nil, | ||||||
|  | 			expectedResponse: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"requestWithBody"}, | ||||||
|  | 					contentType: {applicationJSON}, | ||||||
|  | 				}, | ||||||
|  | 				request: testHTTPRequest{ | ||||||
|  | 					Method: "POST", | ||||||
|  | 					URL:    "http://example.localhost/withBody", | ||||||
|  | 					Header: map[string][]string{ | ||||||
|  | 						contentLength: {"4"}, | ||||||
|  | 					}, | ||||||
|  | 					Body:       []byte("body"), | ||||||
|  | 					Host:       "example.localhost", | ||||||
|  | 					RequestURI: "http://example.localhost/withBody", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("when the upstream is unavailable", &httpUpstreamTableInput{ | ||||||
|  | 			id:           "unavailableUpstream", | ||||||
|  | 			serverAddr:   &invalidServer, | ||||||
|  | 			target:       "http://example.localhost/unavailableUpstream", | ||||||
|  | 			method:       "GET", | ||||||
|  | 			body:         []byte{}, | ||||||
|  | 			errorHandler: nil, | ||||||
|  | 			expectedResponse: testHTTPResponse{ | ||||||
|  | 				code: 502, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"unavailableUpstream"}, | ||||||
|  | 				}, | ||||||
|  | 				request: testHTTPRequest{}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{ | ||||||
|  | 			id:         "withErrorHandler", | ||||||
|  | 			serverAddr: &invalidServer, | ||||||
|  | 			target:     "http://example.localhost/withErrorHandler", | ||||||
|  | 			method:     "GET", | ||||||
|  | 			body:       []byte{}, | ||||||
|  | 			errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) { | ||||||
|  | 				rw.WriteHeader(502) | ||||||
|  | 				rw.Write([]byte("error")) | ||||||
|  | 			}, | ||||||
|  | 			expectedResponse: testHTTPResponse{ | ||||||
|  | 				code: 502, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"withErrorHandler"}, | ||||||
|  | 				}, | ||||||
|  | 				raw:     "error", | ||||||
|  | 				request: testHTTPRequest{}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a signature", &httpUpstreamTableInput{ | ||||||
|  | 			id:         "withSignature", | ||||||
|  | 			serverAddr: &serverAddr, | ||||||
|  | 			target:     "http://example.localhost/withSignature", | ||||||
|  | 			method:     "GET", | ||||||
|  | 			body:       []byte{}, | ||||||
|  | 			signatureData: &options.SignatureData{ | ||||||
|  | 				Hash: crypto.SHA256, | ||||||
|  | 				Key:  "key", | ||||||
|  | 			}, | ||||||
|  | 			errorHandler: nil, | ||||||
|  | 			expectedResponse: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					contentType: {applicationJSON}, | ||||||
|  | 					gapUpstream: {"withSignature"}, | ||||||
|  | 				}, | ||||||
|  | 				request: testHTTPRequest{ | ||||||
|  | 					Method: "GET", | ||||||
|  | 					URL:    "http://example.localhost/withSignature", | ||||||
|  | 					Header: map[string][]string{ | ||||||
|  | 						gapAuth:      {""}, | ||||||
|  | 						gapSignature: {"sha256 osMWI8Rr0Zr5HgNq6wakrgJITVJQMmFN1fXCesrqrmM="}, | ||||||
|  | 					}, | ||||||
|  | 					Body:       []byte{}, | ||||||
|  | 					Host:       "example.localhost", | ||||||
|  | 					RequestURI: "http://example.localhost/withSignature", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with existing headers", &httpUpstreamTableInput{ | ||||||
|  | 			id:           "existingHeaders", | ||||||
|  | 			serverAddr:   &serverAddr, | ||||||
|  | 			target:       "http://example.localhost/existingHeaders", | ||||||
|  | 			method:       "GET", | ||||||
|  | 			body:         []byte{}, | ||||||
|  | 			errorHandler: nil, | ||||||
|  | 			existingHeaders: map[string]string{ | ||||||
|  | 				"Header1": "value1", | ||||||
|  | 				"Header2": "value2", | ||||||
|  | 			}, | ||||||
|  | 			expectedResponse: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"existingHeaders"}, | ||||||
|  | 					contentType: {applicationJSON}, | ||||||
|  | 				}, | ||||||
|  | 				request: testHTTPRequest{ | ||||||
|  | 					Method: "GET", | ||||||
|  | 					URL:    "http://example.localhost/existingHeaders", | ||||||
|  | 					Header: map[string][]string{ | ||||||
|  | 						"Header1": {"value1"}, | ||||||
|  | 						"Header2": {"value2"}, | ||||||
|  | 					}, | ||||||
|  | 					Body:       []byte{}, | ||||||
|  | 					Host:       "example.localhost", | ||||||
|  | 					RequestURI: "http://example.localhost/existingHeaders", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	It("ServeHTTP, when not passing a host header", func() { | ||||||
|  | 		req := httptest.NewRequest("", "http://example.localhost/foo", nil) | ||||||
|  | 		rw := httptest.NewRecorder() | ||||||
|  | 
 | ||||||
|  | 		flush := 1 * time.Second | ||||||
|  | 		upstream := options.Upstream{ | ||||||
|  | 			ID:                    "noPassHost", | ||||||
|  | 			PassHostHeader:        false, | ||||||
|  | 			ProxyWebSockets:       false, | ||||||
|  | 			InsecureSkipTLSVerify: false, | ||||||
|  | 			FlushInterval:         &flush, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		u, err := url.Parse(serverAddr) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		handler := newHTTPUpstreamProxy(upstream, u, nil, nil) | ||||||
|  | 		httpUpstream, ok := handler.(*httpUpstreamProxy) | ||||||
|  | 		Expect(ok).To(BeTrue()) | ||||||
|  | 
 | ||||||
|  | 		// Override the handler to just run the director and not actually send the request
 | ||||||
|  | 		requestInterceptor := func(h http.Handler) http.Handler { | ||||||
|  | 			return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { | ||||||
|  | 				proxy, ok := h.(*httputil.ReverseProxy) | ||||||
|  | 				Expect(ok).To(BeTrue()) | ||||||
|  | 				proxy.Director(req) | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 		httpUpstream.handler = requestInterceptor(httpUpstream.handler) | ||||||
|  | 
 | ||||||
|  | 		httpUpstream.ServeHTTP(rw, req) | ||||||
|  | 		Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	type newUpstreamTableInput struct { | ||||||
|  | 		proxyWebSockets bool | ||||||
|  | 		flushInterval   time.Duration | ||||||
|  | 		skipVerify      bool | ||||||
|  | 		sigData         *options.SignatureData | ||||||
|  | 		errorHandler    func(http.ResponseWriter, *http.Request, error) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DescribeTable("newHTTPUpstreamProxy", | ||||||
|  | 		func(in *newUpstreamTableInput) { | ||||||
|  | 			u, err := url.Parse("http://upstream:1234") | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			upstream := options.Upstream{ | ||||||
|  | 				ID:                    "foo123", | ||||||
|  | 				FlushInterval:         &in.flushInterval, | ||||||
|  | 				InsecureSkipTLSVerify: in.skipVerify, | ||||||
|  | 				ProxyWebSockets:       in.proxyWebSockets, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler) | ||||||
|  | 			upstreamProxy, ok := handler.(*httpUpstreamProxy) | ||||||
|  | 			Expect(ok).To(BeTrue()) | ||||||
|  | 
 | ||||||
|  | 			Expect(upstreamProxy.auth != nil).To(Equal(in.sigData != nil)) | ||||||
|  | 			Expect(upstreamProxy.wsHandler != nil).To(Equal(in.proxyWebSockets)) | ||||||
|  | 			Expect(upstreamProxy.upstream).To(Equal(upstream.ID)) | ||||||
|  | 			Expect(upstreamProxy.handler).ToNot(BeNil()) | ||||||
|  | 
 | ||||||
|  | 			proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy) | ||||||
|  | 			Expect(ok).To(BeTrue()) | ||||||
|  | 			Expect(proxy.FlushInterval).To(Equal(in.flushInterval)) | ||||||
|  | 			Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil)) | ||||||
|  | 			if in.skipVerify { | ||||||
|  | 				Expect(proxy.Transport).To(Equal(&http.Transport{ | ||||||
|  | 					TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, | ||||||
|  | 				})) | ||||||
|  | 			} | ||||||
|  | 		}, | ||||||
|  | 		Entry("with proxy websockets", &newUpstreamTableInput{ | ||||||
|  | 			proxyWebSockets: true, | ||||||
|  | 			flushInterval:   flushInterval1s, | ||||||
|  | 			skipVerify:      false, | ||||||
|  | 			sigData:         nil, | ||||||
|  | 			errorHandler:    nil, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a non standard flush interval", &newUpstreamTableInput{ | ||||||
|  | 			proxyWebSockets: false, | ||||||
|  | 			flushInterval:   flushInterval5s, | ||||||
|  | 			skipVerify:      false, | ||||||
|  | 			sigData:         nil, | ||||||
|  | 			errorHandler:    nil, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a InsecureSkipTLSVerify", &newUpstreamTableInput{ | ||||||
|  | 			proxyWebSockets: false, | ||||||
|  | 			flushInterval:   flushInterval1s, | ||||||
|  | 			skipVerify:      true, | ||||||
|  | 			sigData:         nil, | ||||||
|  | 			errorHandler:    nil, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a SignatureData", &newUpstreamTableInput{ | ||||||
|  | 			proxyWebSockets: false, | ||||||
|  | 			flushInterval:   flushInterval1s, | ||||||
|  | 			skipVerify:      false, | ||||||
|  | 			sigData:         &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}, | ||||||
|  | 			errorHandler:    nil, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with an error handler", &newUpstreamTableInput{ | ||||||
|  | 			proxyWebSockets: false, | ||||||
|  | 			flushInterval:   flushInterval1s, | ||||||
|  | 			skipVerify:      false, | ||||||
|  | 			sigData:         nil, | ||||||
|  | 			errorHandler: func(rw http.ResponseWriter, req *http.Request, arg3 error) { | ||||||
|  | 				rw.WriteHeader(502) | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	Context("with a websocket proxy", func() { | ||||||
|  | 		var proxyServer *httptest.Server | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			flush := 1 * time.Second | ||||||
|  | 			upstream := options.Upstream{ | ||||||
|  | 				ID:                    "websocketProxy", | ||||||
|  | 				PassHostHeader:        true, | ||||||
|  | 				ProxyWebSockets:       true, | ||||||
|  | 				InsecureSkipTLSVerify: false, | ||||||
|  | 				FlushInterval:         &flush, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			u, err := url.Parse(serverAddr) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			handler := newHTTPUpstreamProxy(upstream, u, nil, nil) | ||||||
|  | 			proxyServer = httptest.NewServer(handler) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		AfterEach(func() { | ||||||
|  | 			proxyServer.Close() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("will proxy websockets", func() { | ||||||
|  | 			origin := "http://example.localhost" | ||||||
|  | 			message := "Hello, world!" | ||||||
|  | 
 | ||||||
|  | 			proxyURL, err := url.Parse(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			wsAddr := fmt.Sprintf("ws://%s/", proxyURL.Host) | ||||||
|  | 			ws, err := websocket.Dial(wsAddr, "", origin) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed()) | ||||||
|  | 			var response testWebSocketResponse | ||||||
|  | 			Expect(websocket.JSON.Receive(ws, &response)).To(Succeed()) | ||||||
|  | 			Expect(response).To(Equal(testWebSocketResponse{ | ||||||
|  | 				Message: message, | ||||||
|  | 				Origin:  origin, | ||||||
|  | 			})) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("will proxy HTTP requests", func() { | ||||||
|  | 			response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(response.StatusCode).To(Equal(200)) | ||||||
|  | 			Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,182 @@ | ||||||
|  | package upstream | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"html/template" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Proxy Suite", func() { | ||||||
|  | 	var upstreamServer http.Handler | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} | ||||||
|  | 
 | ||||||
|  | 		tmpl, err := template.New("").Parse("{{ .Title }}\n{{ .Message }}\n{{ .ProxyPrefix }}") | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 		errorHandler := NewProxyErrorHandler(tmpl, "prefix") | ||||||
|  | 
 | ||||||
|  | 		ok := http.StatusOK | ||||||
|  | 
 | ||||||
|  | 		upstreams := options.Upstreams{ | ||||||
|  | 			{ | ||||||
|  | 				ID:   "http-backend", | ||||||
|  | 				Path: "/http/", | ||||||
|  | 				URI:  serverAddr, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				ID:   "file-backend", | ||||||
|  | 				Path: "/files/", | ||||||
|  | 				URI:  fmt.Sprintf("file:///%s", filesDir), | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				ID:         "static-backend", | ||||||
|  | 				Path:       "/static/", | ||||||
|  | 				Static:     true, | ||||||
|  | 				StaticCode: &ok, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				ID:   "bad-http-backend", | ||||||
|  | 				Path: "/bad-http/", | ||||||
|  | 				URI:  "http://::1", | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				ID:         "single-path-backend", | ||||||
|  | 				Path:       "/single-path", | ||||||
|  | 				Static:     true, | ||||||
|  | 				StaticCode: &ok, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	type proxyTableInput struct { | ||||||
|  | 		target   string | ||||||
|  | 		response testHTTPResponse | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DescribeTable("Proxy ServerHTTP", | ||||||
|  | 		func(in *proxyTableInput) { | ||||||
|  | 			req := httptest.NewRequest("", in.target, nil) | ||||||
|  | 			rw := httptest.NewRecorder() | ||||||
|  | 			// Don't mock the remote Address
 | ||||||
|  | 			req.RemoteAddr = "" | ||||||
|  | 
 | ||||||
|  | 			upstreamServer.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 			Expect(rw.Code).To(Equal(in.response.code)) | ||||||
|  | 
 | ||||||
|  | 			// Delete extra headers that aren't relevant to tests
 | ||||||
|  | 			testSanitizeResponseHeader(rw.Header()) | ||||||
|  | 			Expect(rw.Header()).To(Equal(in.response.header)) | ||||||
|  | 
 | ||||||
|  | 			body := rw.Body.Bytes() | ||||||
|  | 			// If the raw body is set, check that, else check the Request object
 | ||||||
|  | 			if in.response.raw != "" { | ||||||
|  | 				Expect(string(body)).To(Equal(in.response.raw)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Compare the reflected request to the upstream
 | ||||||
|  | 			request := testHTTPRequest{} | ||||||
|  | 			Expect(json.Unmarshal(body, &request)).To(Succeed()) | ||||||
|  | 			testSanitizeRequestHeader(request.Header) | ||||||
|  | 			Expect(request).To(Equal(in.response.request)) | ||||||
|  | 		}, | ||||||
|  | 		Entry("with a request to the HTTP service", &proxyTableInput{ | ||||||
|  | 			target: "http://example.localhost/http/1234", | ||||||
|  | 			response: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"http-backend"}, | ||||||
|  | 					contentType: {applicationJSON}, | ||||||
|  | 				}, | ||||||
|  | 				request: testHTTPRequest{ | ||||||
|  | 					Method: "GET", | ||||||
|  | 					URL:    "http://example.localhost/http/1234", | ||||||
|  | 					Header: map[string][]string{ | ||||||
|  | 						"Gap-Auth":      {""}, | ||||||
|  | 						"Gap-Signature": {"sha256 ofB1u6+FhEUbFLc3/uGbJVkl7GaN4egFqVvyO3+2I1w="}, | ||||||
|  | 					}, | ||||||
|  | 					Body:       []byte{}, | ||||||
|  | 					Host:       "example.localhost", | ||||||
|  | 					RequestURI: "http://example.localhost/http/1234", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a request to the File backend", &proxyTableInput{ | ||||||
|  | 			target: "http://example.localhost/files/foo", | ||||||
|  | 			response: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					contentType: {textPlainUTF8}, | ||||||
|  | 					gapUpstream: {"file-backend"}, | ||||||
|  | 				}, | ||||||
|  | 				raw: "foo", | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a request to the Static backend", &proxyTableInput{ | ||||||
|  | 			target: "http://example.localhost/static/bar", | ||||||
|  | 			response: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"static-backend"}, | ||||||
|  | 				}, | ||||||
|  | 				raw: "Authenticated", | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a request to the bad HTTP backend", &proxyTableInput{ | ||||||
|  | 			target: "http://example.localhost/bad-http/bad", | ||||||
|  | 			response: testHTTPResponse{ | ||||||
|  | 				code: 502, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"bad-http-backend"}, | ||||||
|  | 				}, | ||||||
|  | 				// This tests the error handler
 | ||||||
|  | 				raw: "Bad Gateway\nError proxying to upstream server\nprefix", | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a request to the to an unregistered path", &proxyTableInput{ | ||||||
|  | 			target: "http://example.localhost/unregistered", | ||||||
|  | 			response: testHTTPResponse{ | ||||||
|  | 				code: 404, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					"X-Content-Type-Options": {"nosniff"}, | ||||||
|  | 					contentType:              {textPlainUTF8}, | ||||||
|  | 				}, | ||||||
|  | 				raw: "404 page not found\n", | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a request to the to backend registered to a single path", &proxyTableInput{ | ||||||
|  | 			target: "http://example.localhost/single-path", | ||||||
|  | 			response: testHTTPResponse{ | ||||||
|  | 				code: 200, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					gapUpstream: {"single-path-backend"}, | ||||||
|  | 				}, | ||||||
|  | 				raw: "Authenticated", | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ | ||||||
|  | 			target: "http://example.localhost/single-path/unregistered", | ||||||
|  | 			response: testHTTPResponse{ | ||||||
|  | 				code: 404, | ||||||
|  | 				header: map[string][]string{ | ||||||
|  | 					"X-Content-Type-Options": {"nosniff"}, | ||||||
|  | 					contentType:              {textPlainUTF8}, | ||||||
|  | 				}, | ||||||
|  | 				raw: "404 page not found\n", | ||||||
|  | 			}, | ||||||
|  | 		}), | ||||||
|  | 	) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,81 @@ | ||||||
|  | package upstream | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Static Response Suite", func() { | ||||||
|  | 	const authenticated = "Authenticated" | ||||||
|  | 	var id string | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		// Generate a random id before each test to check the GAP-Upstream-Address
 | ||||||
|  | 		// is being set correctly
 | ||||||
|  | 		idBytes := make([]byte, 16) | ||||||
|  | 		_, err := io.ReadFull(rand.Reader, idBytes) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 		id = string(idBytes) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	type serveHTTPTableInput struct { | ||||||
|  | 		requestPath  string | ||||||
|  | 		staticCode   int | ||||||
|  | 		expectedBody string | ||||||
|  | 		expectedCode int | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DescribeTable("staticResponse ServeHTTP", | ||||||
|  | 		func(in *serveHTTPTableInput) { | ||||||
|  | 			var code *int | ||||||
|  | 			if in.staticCode != 0 { | ||||||
|  | 				code = &in.staticCode | ||||||
|  | 			} | ||||||
|  | 			handler := newStaticResponseHandler(id, code) | ||||||
|  | 
 | ||||||
|  | 			req := httptest.NewRequest("", in.requestPath, nil) | ||||||
|  | 			rw := httptest.NewRecorder() | ||||||
|  | 			handler.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 			Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) | ||||||
|  | 			Expect(rw.Code).To(Equal(in.expectedCode)) | ||||||
|  | 			Expect(rw.Body.String()).To(Equal(in.expectedBody)) | ||||||
|  | 		}, | ||||||
|  | 		Entry("with no given code", &serveHTTPTableInput{ | ||||||
|  | 			requestPath:  "/", | ||||||
|  | 			staticCode:   0, // Placeholder for nil
 | ||||||
|  | 			expectedBody: authenticated, | ||||||
|  | 			expectedCode: http.StatusOK, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with status OK", &serveHTTPTableInput{ | ||||||
|  | 			requestPath:  "/abc", | ||||||
|  | 			staticCode:   http.StatusOK, | ||||||
|  | 			expectedBody: authenticated, | ||||||
|  | 			expectedCode: http.StatusOK, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with status NoContent", &serveHTTPTableInput{ | ||||||
|  | 			requestPath:  "/def", | ||||||
|  | 			staticCode:   http.StatusNoContent, | ||||||
|  | 			expectedBody: authenticated, | ||||||
|  | 			expectedCode: http.StatusNoContent, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with status NotFound", &serveHTTPTableInput{ | ||||||
|  | 			requestPath:  "/ghi", | ||||||
|  | 			staticCode:   http.StatusNotFound, | ||||||
|  | 			expectedBody: authenticated, | ||||||
|  | 			expectedCode: http.StatusNotFound, | ||||||
|  | 		}), | ||||||
|  | 		Entry("with status Teapot", &serveHTTPTableInput{ | ||||||
|  | 			requestPath:  "/jkl", | ||||||
|  | 			staticCode:   http.StatusTeapot, | ||||||
|  | 			expectedBody: authenticated, | ||||||
|  | 			expectedCode: http.StatusTeapot, | ||||||
|  | 		}), | ||||||
|  | 	) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,180 @@ | ||||||
|  | package upstream | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"log" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"os" | ||||||
|  | 	"path" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | 	"golang.org/x/net/websocket" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	filesDir      string | ||||||
|  | 	server        *httptest.Server | ||||||
|  | 	serverAddr    string | ||||||
|  | 	invalidServer = "http://::1" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestUpstreamSuite(t *testing.T) { | ||||||
|  | 	logger.SetOutput(GinkgoWriter) | ||||||
|  | 	log.SetOutput(GinkgoWriter) | ||||||
|  | 
 | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "Upstream Suite") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var _ = BeforeSuite(func() { | ||||||
|  | 	// Set up files for serving via file servers
 | ||||||
|  | 	dir, err := ioutil.TempDir("", "oauth2-proxy-upstream-suite") | ||||||
|  | 	Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 	Expect(ioutil.WriteFile(path.Join(dir, "foo"), []byte("foo"), 0644)).To(Succeed()) | ||||||
|  | 	Expect(ioutil.WriteFile(path.Join(dir, "bar"), []byte("bar"), 0644)).To(Succeed()) | ||||||
|  | 	Expect(os.Mkdir(path.Join(dir, "subdir"), os.ModePerm)).To(Succeed()) | ||||||
|  | 	Expect(ioutil.WriteFile(path.Join(dir, "subdir", "baz"), []byte("baz"), 0644)).To(Succeed()) | ||||||
|  | 	filesDir = dir | ||||||
|  | 
 | ||||||
|  | 	// Set up a webserver that reflects requests
 | ||||||
|  | 	server = httptest.NewServer(&testHTTPUpstream{}) | ||||||
|  | 	serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) | ||||||
|  | }) | ||||||
|  | 
 | ||||||
|  | var _ = AfterSuite(func() { | ||||||
|  | 	server.Close() | ||||||
|  | 	Expect(os.RemoveAll(filesDir)).To(Succeed()) | ||||||
|  | }) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	contentType     = "Content-Type" | ||||||
|  | 	contentLength   = "Content-Length" | ||||||
|  | 	acceptEncoding  = "Accept-Encoding" | ||||||
|  | 	applicationJSON = "application/json" | ||||||
|  | 	textPlainUTF8   = "text/plain; charset=utf-8" | ||||||
|  | 	gapUpstream     = "Gap-Upstream-Address" | ||||||
|  | 	gapAuth         = "Gap-Auth" | ||||||
|  | 	gapSignature    = "Gap-Signature" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // testHTTPResponse is a struct used for checking responses in table tests
 | ||||||
|  | type testHTTPResponse struct { | ||||||
|  | 	code    int | ||||||
|  | 	header  http.Header | ||||||
|  | 	raw     string | ||||||
|  | 	request testHTTPRequest | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // testHTTPRequest is a struct used to capture the state of a request made to
 | ||||||
|  | // an upstream during a test
 | ||||||
|  | type testHTTPRequest struct { | ||||||
|  | 	Method     string | ||||||
|  | 	URL        string | ||||||
|  | 	Header     http.Header | ||||||
|  | 	Body       []byte | ||||||
|  | 	Host       string | ||||||
|  | 	RequestURI string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type testWebSocketResponse struct { | ||||||
|  | 	Message string | ||||||
|  | 	Origin  string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type testHTTPUpstream struct{} | ||||||
|  | 
 | ||||||
|  | func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 	if req.Header.Get("Upgrade") == "websocket" { | ||||||
|  | 		t.websocketHandler().ServeHTTP(rw, req) | ||||||
|  | 	} else { | ||||||
|  | 		t.serveHTTP(rw, req) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (t *testHTTPUpstream) serveHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 	request, err := toTestHTTPRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.writeError(rw, err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	data, err := json.Marshal(request) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.writeError(rw, err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rw.Header().Set("Content-Type", "application/json") | ||||||
|  | 	rw.Write(data) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (t *testHTTPUpstream) websocketHandler() http.Handler { | ||||||
|  | 	return websocket.Handler(func(ws *websocket.Conn) { | ||||||
|  | 		defer ws.Close() | ||||||
|  | 		var data []byte | ||||||
|  | 		err := websocket.Message.Receive(ws, &data) | ||||||
|  | 		if err != nil { | ||||||
|  | 			websocket.Message.Send(ws, []byte(err.Error())) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		wsResponse := testWebSocketResponse{ | ||||||
|  | 			Message: string(data), | ||||||
|  | 			Origin:  ws.Request().Header.Get("Origin"), | ||||||
|  | 		} | ||||||
|  | 		err = websocket.JSON.Send(ws, wsResponse) | ||||||
|  | 		if err != nil { | ||||||
|  | 			websocket.Message.Send(ws, []byte(err.Error())) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { | ||||||
|  | 	rw.WriteHeader(500) | ||||||
|  | 	if err != nil { | ||||||
|  | 		rw.Write([]byte(err.Error())) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { | ||||||
|  | 	requestBody := []byte{} | ||||||
|  | 	if req.Body != http.NoBody { | ||||||
|  | 		var err error | ||||||
|  | 		requestBody, err = ioutil.ReadAll(req.Body) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return testHTTPRequest{}, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return testHTTPRequest{ | ||||||
|  | 		Method:     req.Method, | ||||||
|  | 		URL:        req.URL.String(), | ||||||
|  | 		Header:     req.Header, | ||||||
|  | 		Body:       requestBody, | ||||||
|  | 		Host:       req.Host, | ||||||
|  | 		RequestURI: req.RequestURI, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // String headers added to the response that we do not want to test
 | ||||||
|  | func testSanitizeResponseHeader(h http.Header) { | ||||||
|  | 	// From HTTP responses
 | ||||||
|  | 	h.Del("Date") | ||||||
|  | 	h.Del(contentLength) | ||||||
|  | 
 | ||||||
|  | 	// From File responses
 | ||||||
|  | 	h.Del("Accept-Ranges") | ||||||
|  | 	h.Del("Last-Modified") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Strip the accept header that is added by the HTTP Transport
 | ||||||
|  | func testSanitizeRequestHeader(h http.Header) { | ||||||
|  | 	h.Del(acceptEncoding) | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue