diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d92a1bd..17b81383 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v6.0.0 +- [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed) - [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed) - [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves) - [#619](https://github.com/oauth2-proxy/oauth2-proxy/pull/619) Improve Redirect to HTTPs behaviour (@JoelSpeed) diff --git a/pkg/apis/options/upstreams.go b/pkg/apis/options/upstreams.go new file mode 100644 index 00000000..5a8eebe5 --- /dev/null +++ b/pkg/apis/options/upstreams.go @@ -0,0 +1,60 @@ +package options + +import "time" + +// Upstreams is a collection of definitions for upstream servers. +type Upstreams []Upstream + +// Upstream represents the configuration for an upstream server. +// Requests will be proxied to this upstream if the path matches the request path. +type Upstream struct { + // ID should be a unique identifier for the upstream. + // This value is required for all upstreams. + ID string `json:"id"` + + // Path is used to map requests to the upstream server. + // The closest match will take precedence and all Paths must be unique. + Path string `json:"path"` + + // The URI of the upstream server. This may be an HTTP(S) server of a File + // based URL. It may include a path, in which case all requests will be served + // under that path. + // Eg: + // - http://localhost:8080 + // - https://service.localhost + // - https://service.localhost/path + // - file://host/path + // If the URI's path is "/base" and the incoming request was for "/dir", + // the upstream request will be for "/base/dir". + URI string `json:"uri"` + + // InsecureSkipTLSVerify will skip TLS verification of upstream HTTPS hosts. + // This option is insecure and will allow potential Man-In-The-Middle attacks + // betweem OAuth2 Proxy and the usptream server. + // Defaults to false. + InsecureSkipTLSVerify bool `json:"insecureSkipTLSVerify"` + + // Static will make all requests to this upstream have a static response. + // The response will have a body of "Authenticated" and a response code + // matching StaticCode. + // If StaticCode is not set, the response will return a 200 response. + Static bool `json:"static"` + + // StaticCode determines the response code for the Static response. + // This option can only be used with Static enabled. + StaticCode *int `json:"staticCode,omitempty"` + + // FlushInterval is the period between flushing the response buffer when + // streaming response from the upstream. + // Defaults to 1 second. + FlushInterval *time.Duration `json:"flushInterval,omitempty"` + + // PassHostHeader determines whether the request host header should be proxied + // to the upstream server. + // Defaults to true. + PassHostHeader bool `json:"passHostHeader"` + + // ProxyWebSockets enables proxying of websockets to upstream servers + // Defaults to true. + ProxyWebSockets bool `json:"proxyWebSockets"` +} diff --git a/pkg/upstream/file.go b/pkg/upstream/file.go new file mode 100644 index 00000000..7f67edb0 --- /dev/null +++ b/pkg/upstream/file.go @@ -0,0 +1,42 @@ +package upstream + +import ( + "net/http" + "runtime" + "strings" +) + +const fileScheme = "file" + +// newFileServer creates a new fileServer that can serve requests +// to a file system location. +func newFileServer(id, path, fileSystemPath string) http.Handler { + return &fileServer{ + upstream: id, + handler: newFileServerForPath(path, fileSystemPath), + } +} + +// newFileServerForPath creates a http.Handler to serve files from the filesystem +func newFileServerForPath(path string, filesystemPath string) http.Handler { + // Windows fileSSystemPath will be be prefixed with `/`, eg`/C:/..., + // if they were parsed by url.Parse` + if runtime.GOOS == "windows" { + filesystemPath = strings.TrimPrefix(filesystemPath, "/") + } + + return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) +} + +// fileServer represents a single filesystem upstream proxy +type fileServer struct { + upstream string + handler http.Handler +} + +// ServeHTTP proxies requests to the upstream provider while signing the +// request headers +func (u *fileServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("GAP-Upstream-Address", u.upstream) + u.handler.ServeHTTP(rw, req) +} diff --git a/pkg/upstream/file_test.go b/pkg/upstream/file_test.go new file mode 100644 index 00000000..2da1f078 --- /dev/null +++ b/pkg/upstream/file_test.go @@ -0,0 +1,58 @@ +package upstream + +import ( + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("File Server Suite", func() { + var dir string + var handler http.Handler + var id string + + const ( + foo = "foo" + bar = "bar" + baz = "baz" + pageNotFound = "404 page not found\n" + ) + + BeforeEach(func() { + // Generate a random id before each test to check the GAP-Upstream-Address + // is being set correctly + idBytes := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, idBytes) + Expect(err).ToNot(HaveOccurred()) + id = string(idBytes) + + handler = newFileServer(id, "/files", filesDir) + }) + + AfterEach(func() { + Expect(os.RemoveAll(dir)).To(Succeed()) + }) + + DescribeTable("fileServer ServeHTTP", + func(requestPath string, expectedResponseCode int, expectedBody string) { + req := httptest.NewRequest("", requestPath, nil) + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, req) + + Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) + Expect(rw.Code).To(Equal(expectedResponseCode)) + Expect(rw.Body.String()).To(Equal(expectedBody)) + }, + Entry("for file foo", "/files/foo", 200, foo), + Entry("for file bar", "/files/bar", 200, bar), + Entry("for file foo/baz", "/files/subdir/baz", 200, baz), + Entry("for a non-existent file inside the path", "/files/baz", 404, pageNotFound), + Entry("for a non-existent file oustide the path", "/baz", 404, pageNotFound), + ) +}) diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go new file mode 100644 index 00000000..fa7b2a8a --- /dev/null +++ b/pkg/upstream/http.go @@ -0,0 +1,163 @@ +package upstream + +import ( + "crypto/tls" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/mbland/hmacauth" + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + "github.com/yhat/wsutil" +) + +const ( + // SignatureHeader is the name of the request header containing the GAP Signature + // Part of hmacauth + SignatureHeader = "GAP-Signature" + + httpScheme = "http" + httpsScheme = "https" +) + +// SignatureHeaders contains the headers to be signed by the hmac algorithm +// Part of hmacauth +var SignatureHeaders = []string{ + "Content-Length", + "Content-Md5", + "Content-Type", + "Date", + "Authorization", + "X-Forwarded-User", + "X-Forwarded-Email", + "X-Forwarded-Preferred-User", + "X-Forwarded-Access-Token", + "Cookie", + "Gap-Auth", +} + +// newHTTPUpstreamProxy creates a new httpUpstreamProxy that can serve requests +// to a single upstream host. +func newHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, errorHandler ProxyErrorHandler) http.Handler { + // Set path to empty so that request paths start at the server root + u.Path = "" + + // Create a ReverseProxy + proxy := newReverseProxy(u, upstream, errorHandler) + + // Set up a WebSocket proxy if required + var wsProxy http.Handler + if upstream.ProxyWebSockets { + wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify) + } + + var auth hmacauth.HmacAuth + if sigData != nil { + auth = hmacauth.NewHmacAuth(sigData.Hash, []byte(sigData.Key), SignatureHeader, SignatureHeaders) + } + + return &httpUpstreamProxy{ + upstream: upstream.ID, + handler: proxy, + wsHandler: wsProxy, + auth: auth, + } +} + +// httpUpstreamProxy represents a single HTTP(S) upstream proxy +type httpUpstreamProxy struct { + upstream string + handler http.Handler + wsHandler http.Handler + auth hmacauth.HmacAuth +} + +// ServeHTTP proxies requests to the upstream provider while signing the +// request headers +func (h *httpUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("GAP-Upstream-Address", h.upstream) + if h.auth != nil { + req.Header.Set("GAP-Auth", rw.Header().Get("GAP-Auth")) + h.auth.SignRequest(req) + } + if h.wsHandler != nil && strings.EqualFold(req.Header.Get("Connection"), "upgrade") && req.Header.Get("Upgrade") == "websocket" { + h.wsHandler.ServeHTTP(rw, req) + } else { + h.handler.ServeHTTP(rw, req) + } +} + +// newReverseProxy creates a new reverse proxy for proxying requests to upstream +// servers based on the upstream configuration provided. +// The proxy should render an error page if there are failures connecting to the +// upstream server. +func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler ProxyErrorHandler) http.Handler { + proxy := httputil.NewSingleHostReverseProxy(target) + + // Configure options on the SingleHostReverseProxy + if upstream.FlushInterval != nil { + proxy.FlushInterval = *upstream.FlushInterval + } else { + proxy.FlushInterval = 1 * time.Second + } + + if upstream.InsecureSkipTLSVerify { + proxy.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + } + + // Set the request director based on the PassHostHeader option + if !upstream.PassHostHeader { + setProxyUpstreamHostHeader(proxy, target) + } else { + setProxyDirector(proxy) + } + + // Set the error handler so that upstream connection failures render the + // error page instead of sending a empty response + if errorHandler != nil { + proxy.ErrorHandler = errorHandler + } + return proxy +} + +// setProxyUpstreamHostHeader sets the proxy.Director so that upstream requests +// receive a host header matching the target URL. +func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { + director := proxy.Director + proxy.Director = func(req *http.Request) { + director(req) + // use RequestURI so that we aren't unescaping encoded slashes in the request path + req.Host = target.Host + req.URL.Opaque = req.RequestURI + req.URL.RawQuery = "" + } +} + +// setProxyDirector sets the proxy.Director so that request URIs are escaped +// when proxying to usptream servers. +func setProxyDirector(proxy *httputil.ReverseProxy) { + director := proxy.Director + proxy.Director = func(req *http.Request) { + director(req) + // use RequestURI so that we aren't unescaping encoded slashes in the request path + req.URL.Opaque = req.RequestURI + req.URL.RawQuery = "" + } +} + +// newWebSocketReverseProxy creates a new reverse proxy for proxying websocket connections. +func newWebSocketReverseProxy(u *url.URL, skipTLSVerify bool) http.Handler { + // This should create the correct scheme for insecure vs secure connections + wsScheme := "ws" + strings.TrimPrefix(u.Scheme, "http") + wsURL := &url.URL{Scheme: wsScheme, Host: u.Host} + + wsProxy := wsutil.NewSingleHostReverseProxy(wsURL) + if skipTLSVerify { + wsProxy.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + return wsProxy +} diff --git a/pkg/upstream/http_test.go b/pkg/upstream/http_test.go new file mode 100644 index 00000000..8c601880 --- /dev/null +++ b/pkg/upstream/http_test.go @@ -0,0 +1,417 @@ +package upstream + +import ( + "bytes" + "crypto" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" + "golang.org/x/net/websocket" +) + +var _ = Describe("HTTP Upstream Suite", func() { + + const flushInterval5s = 5 * time.Second + const flushInterval1s = 1 * time.Second + + type httpUpstreamTableInput struct { + id string + serverAddr *string + target string + method string + body []byte + signatureData *options.SignatureData + existingHeaders map[string]string + expectedResponse testHTTPResponse + errorHandler ProxyErrorHandler + } + + DescribeTable("HTTP Upstream ServeHTTP", + func(in *httpUpstreamTableInput) { + buf := bytes.NewBuffer(in.body) + req := httptest.NewRequest(in.method, in.target, buf) + // Don't mock the remote Address + req.RemoteAddr = "" + + for key, value := range in.existingHeaders { + req.Header.Add(key, value) + } + + rw := httptest.NewRecorder() + + flush := 1 * time.Second + upstream := options.Upstream{ + ID: in.id, + PassHostHeader: true, + ProxyWebSockets: false, + InsecureSkipTLSVerify: false, + FlushInterval: &flush, + } + + Expect(in.serverAddr).ToNot(BeNil()) + u, err := url.Parse(*in.serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, in.signatureData, in.errorHandler) + handler.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.expectedResponse.code)) + + // Delete extra headers that aren't relevant to tests + testSanitizeResponseHeader(rw.Header()) + Expect(rw.Header()).To(Equal(in.expectedResponse.header)) + + body := rw.Body.Bytes() + if in.expectedResponse.raw != "" || rw.Code != http.StatusOK { + Expect(string(body)).To(Equal(in.expectedResponse.raw)) + return + } + + // Compare the reflected request to the upstream + request := testHTTPRequest{} + Expect(json.Unmarshal(body, &request)).To(Succeed()) + testSanitizeRequestHeader(request.Header) + Expect(request).To(Equal(in.expectedResponse.request)) + }, + Entry("request a path on the server", &httpUpstreamTableInput{ + id: "default", + serverAddr: &serverAddr, + target: "http://example.localhost/foo", + method: "GET", + body: []byte{}, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"default"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/foo", + Header: map[string][]string{}, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/foo", + }, + }, + }), + Entry("request a path with encoded slashes", &httpUpstreamTableInput{ + id: "encodedSlashes", + serverAddr: &serverAddr, + target: "http://example.localhost/foo%2fbar/?baz=1", + method: "GET", + body: []byte{}, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"encodedSlashes"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/foo%2fbar/?baz=1", + Header: map[string][]string{}, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/foo%2fbar/?baz=1", + }, + }, + }), + Entry("when the request has a body", &httpUpstreamTableInput{ + id: "requestWithBody", + serverAddr: &serverAddr, + target: "http://example.localhost/withBody", + method: "POST", + body: []byte("body"), + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"requestWithBody"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "POST", + URL: "http://example.localhost/withBody", + Header: map[string][]string{ + contentLength: {"4"}, + }, + Body: []byte("body"), + Host: "example.localhost", + RequestURI: "http://example.localhost/withBody", + }, + }, + }), + Entry("when the upstream is unavailable", &httpUpstreamTableInput{ + id: "unavailableUpstream", + serverAddr: &invalidServer, + target: "http://example.localhost/unavailableUpstream", + method: "GET", + body: []byte{}, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 502, + header: map[string][]string{ + gapUpstream: {"unavailableUpstream"}, + }, + request: testHTTPRequest{}, + }, + }), + Entry("when the upstream is unavailable and an error handler is set", &httpUpstreamTableInput{ + id: "withErrorHandler", + serverAddr: &invalidServer, + target: "http://example.localhost/withErrorHandler", + method: "GET", + body: []byte{}, + errorHandler: func(rw http.ResponseWriter, _ *http.Request, _ error) { + rw.WriteHeader(502) + rw.Write([]byte("error")) + }, + expectedResponse: testHTTPResponse{ + code: 502, + header: map[string][]string{ + gapUpstream: {"withErrorHandler"}, + }, + raw: "error", + request: testHTTPRequest{}, + }, + }), + Entry("with a signature", &httpUpstreamTableInput{ + id: "withSignature", + serverAddr: &serverAddr, + target: "http://example.localhost/withSignature", + method: "GET", + body: []byte{}, + signatureData: &options.SignatureData{ + Hash: crypto.SHA256, + Key: "key", + }, + errorHandler: nil, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + contentType: {applicationJSON}, + gapUpstream: {"withSignature"}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/withSignature", + Header: map[string][]string{ + gapAuth: {""}, + gapSignature: {"sha256 osMWI8Rr0Zr5HgNq6wakrgJITVJQMmFN1fXCesrqrmM="}, + }, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/withSignature", + }, + }, + }), + Entry("with existing headers", &httpUpstreamTableInput{ + id: "existingHeaders", + serverAddr: &serverAddr, + target: "http://example.localhost/existingHeaders", + method: "GET", + body: []byte{}, + errorHandler: nil, + existingHeaders: map[string]string{ + "Header1": "value1", + "Header2": "value2", + }, + expectedResponse: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"existingHeaders"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/existingHeaders", + Header: map[string][]string{ + "Header1": {"value1"}, + "Header2": {"value2"}, + }, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/existingHeaders", + }, + }, + }), + ) + + It("ServeHTTP, when not passing a host header", func() { + req := httptest.NewRequest("", "http://example.localhost/foo", nil) + rw := httptest.NewRecorder() + + flush := 1 * time.Second + upstream := options.Upstream{ + ID: "noPassHost", + PassHostHeader: false, + ProxyWebSockets: false, + InsecureSkipTLSVerify: false, + FlushInterval: &flush, + } + + u, err := url.Parse(serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, nil, nil) + httpUpstream, ok := handler.(*httpUpstreamProxy) + Expect(ok).To(BeTrue()) + + // Override the handler to just run the director and not actually send the request + requestInterceptor := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + proxy, ok := h.(*httputil.ReverseProxy) + Expect(ok).To(BeTrue()) + proxy.Director(req) + }) + } + httpUpstream.handler = requestInterceptor(httpUpstream.handler) + + httpUpstream.ServeHTTP(rw, req) + Expect(req.Host).To(Equal(strings.TrimPrefix(serverAddr, "http://"))) + }) + + type newUpstreamTableInput struct { + proxyWebSockets bool + flushInterval time.Duration + skipVerify bool + sigData *options.SignatureData + errorHandler func(http.ResponseWriter, *http.Request, error) + } + + DescribeTable("newHTTPUpstreamProxy", + func(in *newUpstreamTableInput) { + u, err := url.Parse("http://upstream:1234") + Expect(err).ToNot(HaveOccurred()) + + upstream := options.Upstream{ + ID: "foo123", + FlushInterval: &in.flushInterval, + InsecureSkipTLSVerify: in.skipVerify, + ProxyWebSockets: in.proxyWebSockets, + } + + handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler) + upstreamProxy, ok := handler.(*httpUpstreamProxy) + Expect(ok).To(BeTrue()) + + Expect(upstreamProxy.auth != nil).To(Equal(in.sigData != nil)) + Expect(upstreamProxy.wsHandler != nil).To(Equal(in.proxyWebSockets)) + Expect(upstreamProxy.upstream).To(Equal(upstream.ID)) + Expect(upstreamProxy.handler).ToNot(BeNil()) + + proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy) + Expect(ok).To(BeTrue()) + Expect(proxy.FlushInterval).To(Equal(in.flushInterval)) + Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil)) + if in.skipVerify { + Expect(proxy.Transport).To(Equal(&http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + })) + } + }, + Entry("with proxy websockets", &newUpstreamTableInput{ + proxyWebSockets: true, + flushInterval: flushInterval1s, + skipVerify: false, + sigData: nil, + errorHandler: nil, + }), + Entry("with a non standard flush interval", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval5s, + skipVerify: false, + sigData: nil, + errorHandler: nil, + }), + Entry("with a InsecureSkipTLSVerify", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval1s, + skipVerify: true, + sigData: nil, + errorHandler: nil, + }), + Entry("with a SignatureData", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval1s, + skipVerify: false, + sigData: &options.SignatureData{Hash: crypto.SHA256, Key: "secret"}, + errorHandler: nil, + }), + Entry("with an error handler", &newUpstreamTableInput{ + proxyWebSockets: false, + flushInterval: flushInterval1s, + skipVerify: false, + sigData: nil, + errorHandler: func(rw http.ResponseWriter, req *http.Request, arg3 error) { + rw.WriteHeader(502) + }, + }), + ) + + Context("with a websocket proxy", func() { + var proxyServer *httptest.Server + + BeforeEach(func() { + flush := 1 * time.Second + upstream := options.Upstream{ + ID: "websocketProxy", + PassHostHeader: true, + ProxyWebSockets: true, + InsecureSkipTLSVerify: false, + FlushInterval: &flush, + } + + u, err := url.Parse(serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, nil, nil) + proxyServer = httptest.NewServer(handler) + }) + + AfterEach(func() { + proxyServer.Close() + }) + + It("will proxy websockets", func() { + origin := "http://example.localhost" + message := "Hello, world!" + + proxyURL, err := url.Parse(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) + Expect(err).ToNot(HaveOccurred()) + + wsAddr := fmt.Sprintf("ws://%s/", proxyURL.Host) + ws, err := websocket.Dial(wsAddr, "", origin) + Expect(err).ToNot(HaveOccurred()) + + Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed()) + var response testWebSocketResponse + Expect(websocket.JSON.Receive(ws, &response)).To(Succeed()) + Expect(response).To(Equal(testWebSocketResponse{ + Message: message, + Origin: origin, + })) + }) + + It("will proxy HTTP requests", func() { + response, err := http.Get(fmt.Sprintf("http://%s", proxyServer.Listener.Addr().String())) + Expect(err).ToNot(HaveOccurred()) + Expect(response.StatusCode).To(Equal(200)) + Expect(response.Header.Get(gapUpstream)).To(Equal("websocketProxy")) + }) + }) +}) diff --git a/pkg/upstream/proxy.go b/pkg/upstream/proxy.go new file mode 100644 index 00000000..6c7c581b --- /dev/null +++ b/pkg/upstream/proxy.go @@ -0,0 +1,90 @@ +package upstream + +import ( + "fmt" + "html/template" + "net/http" + "net/url" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +// ProxyErrorHandler is a function that will be used to render error pages when +// HTTP proxies fail to connect to upstream servers. +type ProxyErrorHandler func(http.ResponseWriter, *http.Request, error) + +// NewProxy creates a new multiUpstreamProxy that can serve requests directed to +// multiple upstreams. +func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, errorHandler ProxyErrorHandler) (http.Handler, error) { + m := &multiUpstreamProxy{ + serveMux: http.NewServeMux(), + } + + for _, upstream := range upstreams { + if upstream.Static { + m.registerStaticResponseHandler(upstream) + continue + } + + u, err := url.Parse(upstream.URI) + if err != nil { + return nil, fmt.Errorf("error parsing URI for upstream %q: %w", upstream.ID, err) + } + switch u.Scheme { + case fileScheme: + m.registerFileServer(upstream, u) + case httpScheme, httpsScheme: + m.registerHTTPUpstreamProxy(upstream, u, sigData, errorHandler) + default: + return nil, fmt.Errorf("unknown scheme for upstream %q: %q", upstream.ID, u.Scheme) + } + } + return m, nil +} + +// multiUpstreamProxy will serve requests directed to multiple upstream servers +// registered in the serverMux. +type multiUpstreamProxy struct { + serveMux *http.ServeMux +} + +// ServerHTTP handles HTTP requests. +func (m *multiUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + m.serveMux.ServeHTTP(rw, req) +} + +// registerStaticResponseHandler registers a static response handler with at the given path. +func (m *multiUpstreamProxy) registerStaticResponseHandler(upstream options.Upstream) { + m.serveMux.Handle(upstream.Path, newStaticResponseHandler(upstream.ID, upstream.StaticCode)) +} + +// registerFileServer registers a new fileServer based on the configuration given. +func (m *multiUpstreamProxy) registerFileServer(upstream options.Upstream, u *url.URL) { + logger.Printf("mapping path %q => file system %q", upstream.Path, u.Path) + m.serveMux.Handle(upstream.Path, newFileServer(upstream.ID, upstream.Path, u.Path)) +} + +// registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given. +func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, errorHandler ProxyErrorHandler) { + logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI) + m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, errorHandler)) +} + +// NewProxyErrorHandler creates a ProxyErrorHandler using the template given. +func NewProxyErrorHandler(errorTemplate *template.Template, proxyPrefix string) ProxyErrorHandler { + return func(rw http.ResponseWriter, req *http.Request, proxyErr error) { + logger.Printf("Error proxying to upstream server: %v", proxyErr) + rw.WriteHeader(http.StatusBadGateway) + data := struct { + Title string + Message string + ProxyPrefix string + }{ + Title: "Bad Gateway", + Message: "Error proxying to upstream server", + ProxyPrefix: proxyPrefix, + } + errorTemplate.Execute(rw, data) + } +} diff --git a/pkg/upstream/proxy_test.go b/pkg/upstream/proxy_test.go new file mode 100644 index 00000000..945fb665 --- /dev/null +++ b/pkg/upstream/proxy_test.go @@ -0,0 +1,182 @@ +package upstream + +import ( + "crypto" + "encoding/json" + "fmt" + "html/template" + "net/http" + "net/http/httptest" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Proxy Suite", func() { + var upstreamServer http.Handler + + BeforeEach(func() { + sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} + + tmpl, err := template.New("").Parse("{{ .Title }}\n{{ .Message }}\n{{ .ProxyPrefix }}") + Expect(err).ToNot(HaveOccurred()) + errorHandler := NewProxyErrorHandler(tmpl, "prefix") + + ok := http.StatusOK + + upstreams := options.Upstreams{ + { + ID: "http-backend", + Path: "/http/", + URI: serverAddr, + }, + { + ID: "file-backend", + Path: "/files/", + URI: fmt.Sprintf("file:///%s", filesDir), + }, + { + ID: "static-backend", + Path: "/static/", + Static: true, + StaticCode: &ok, + }, + { + ID: "bad-http-backend", + Path: "/bad-http/", + URI: "http://::1", + }, + { + ID: "single-path-backend", + Path: "/single-path", + Static: true, + StaticCode: &ok, + }, + } + + upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) + Expect(err).ToNot(HaveOccurred()) + }) + + type proxyTableInput struct { + target string + response testHTTPResponse + } + + DescribeTable("Proxy ServerHTTP", + func(in *proxyTableInput) { + req := httptest.NewRequest("", in.target, nil) + rw := httptest.NewRecorder() + // Don't mock the remote Address + req.RemoteAddr = "" + + upstreamServer.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.response.code)) + + // Delete extra headers that aren't relevant to tests + testSanitizeResponseHeader(rw.Header()) + Expect(rw.Header()).To(Equal(in.response.header)) + + body := rw.Body.Bytes() + // If the raw body is set, check that, else check the Request object + if in.response.raw != "" { + Expect(string(body)).To(Equal(in.response.raw)) + return + } + + // Compare the reflected request to the upstream + request := testHTTPRequest{} + Expect(json.Unmarshal(body, &request)).To(Succeed()) + testSanitizeRequestHeader(request.Header) + Expect(request).To(Equal(in.response.request)) + }, + Entry("with a request to the HTTP service", &proxyTableInput{ + target: "http://example.localhost/http/1234", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"http-backend"}, + contentType: {applicationJSON}, + }, + request: testHTTPRequest{ + Method: "GET", + URL: "http://example.localhost/http/1234", + Header: map[string][]string{ + "Gap-Auth": {""}, + "Gap-Signature": {"sha256 ofB1u6+FhEUbFLc3/uGbJVkl7GaN4egFqVvyO3+2I1w="}, + }, + Body: []byte{}, + Host: "example.localhost", + RequestURI: "http://example.localhost/http/1234", + }, + }, + }), + Entry("with a request to the File backend", &proxyTableInput{ + target: "http://example.localhost/files/foo", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + contentType: {textPlainUTF8}, + gapUpstream: {"file-backend"}, + }, + raw: "foo", + }, + }), + Entry("with a request to the Static backend", &proxyTableInput{ + target: "http://example.localhost/static/bar", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"static-backend"}, + }, + raw: "Authenticated", + }, + }), + Entry("with a request to the bad HTTP backend", &proxyTableInput{ + target: "http://example.localhost/bad-http/bad", + response: testHTTPResponse{ + code: 502, + header: map[string][]string{ + gapUpstream: {"bad-http-backend"}, + }, + // This tests the error handler + raw: "Bad Gateway\nError proxying to upstream server\nprefix", + }, + }), + Entry("with a request to the to an unregistered path", &proxyTableInput{ + target: "http://example.localhost/unregistered", + response: testHTTPResponse{ + code: 404, + header: map[string][]string{ + "X-Content-Type-Options": {"nosniff"}, + contentType: {textPlainUTF8}, + }, + raw: "404 page not found\n", + }, + }), + Entry("with a request to the to backend registered to a single path", &proxyTableInput{ + target: "http://example.localhost/single-path", + response: testHTTPResponse{ + code: 200, + header: map[string][]string{ + gapUpstream: {"single-path-backend"}, + }, + raw: "Authenticated", + }, + }), + Entry("with a request to the to a subpath of a backend registered to a single path", &proxyTableInput{ + target: "http://example.localhost/single-path/unregistered", + response: testHTTPResponse{ + code: 404, + header: map[string][]string{ + "X-Content-Type-Options": {"nosniff"}, + contentType: {textPlainUTF8}, + }, + raw: "404 page not found\n", + }, + }), + ) +}) diff --git a/pkg/upstream/static.go b/pkg/upstream/static.go new file mode 100644 index 00000000..0a061421 --- /dev/null +++ b/pkg/upstream/static.go @@ -0,0 +1,34 @@ +package upstream + +import ( + "fmt" + "net/http" +) + +const defaultStaticResponseCode = 200 + +// newStaticResponseHandler creates a new staticResponseHandler that serves a +// a static response code. +func newStaticResponseHandler(upstream string, code *int) http.Handler { + if code == nil { + c := defaultStaticResponseCode + code = &c + } + return &staticResponseHandler{ + code: *code, + upstream: upstream, + } +} + +// staticResponseHandler responds with a static response with the given response code. +type staticResponseHandler struct { + code int + upstream string +} + +// ServeHTTP serves a static response. +func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("GAP-Upstream-Address", s.upstream) + rw.WriteHeader(s.code) + fmt.Fprintf(rw, "Authenticated") +} diff --git a/pkg/upstream/static_test.go b/pkg/upstream/static_test.go new file mode 100644 index 00000000..1b7309f7 --- /dev/null +++ b/pkg/upstream/static_test.go @@ -0,0 +1,81 @@ +package upstream + +import ( + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Static Response Suite", func() { + const authenticated = "Authenticated" + var id string + + BeforeEach(func() { + // Generate a random id before each test to check the GAP-Upstream-Address + // is being set correctly + idBytes := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, idBytes) + Expect(err).ToNot(HaveOccurred()) + id = string(idBytes) + }) + + type serveHTTPTableInput struct { + requestPath string + staticCode int + expectedBody string + expectedCode int + } + + DescribeTable("staticResponse ServeHTTP", + func(in *serveHTTPTableInput) { + var code *int + if in.staticCode != 0 { + code = &in.staticCode + } + handler := newStaticResponseHandler(id, code) + + req := httptest.NewRequest("", in.requestPath, nil) + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, req) + + Expect(rw.Header().Get("GAP-Upstream-Address")).To(Equal(id)) + Expect(rw.Code).To(Equal(in.expectedCode)) + Expect(rw.Body.String()).To(Equal(in.expectedBody)) + }, + Entry("with no given code", &serveHTTPTableInput{ + requestPath: "/", + staticCode: 0, // Placeholder for nil + expectedBody: authenticated, + expectedCode: http.StatusOK, + }), + Entry("with status OK", &serveHTTPTableInput{ + requestPath: "/abc", + staticCode: http.StatusOK, + expectedBody: authenticated, + expectedCode: http.StatusOK, + }), + Entry("with status NoContent", &serveHTTPTableInput{ + requestPath: "/def", + staticCode: http.StatusNoContent, + expectedBody: authenticated, + expectedCode: http.StatusNoContent, + }), + Entry("with status NotFound", &serveHTTPTableInput{ + requestPath: "/ghi", + staticCode: http.StatusNotFound, + expectedBody: authenticated, + expectedCode: http.StatusNotFound, + }), + Entry("with status Teapot", &serveHTTPTableInput{ + requestPath: "/jkl", + staticCode: http.StatusTeapot, + expectedBody: authenticated, + expectedCode: http.StatusTeapot, + }), + ) +}) diff --git a/pkg/upstream/upstream_suite_test.go b/pkg/upstream/upstream_suite_test.go new file mode 100644 index 00000000..7d8c2ba4 --- /dev/null +++ b/pkg/upstream/upstream_suite_test.go @@ -0,0 +1,180 @@ +package upstream + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "path" + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "golang.org/x/net/websocket" +) + +var ( + filesDir string + server *httptest.Server + serverAddr string + invalidServer = "http://::1" +) + +func TestUpstreamSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + log.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Upstream Suite") +} + +var _ = BeforeSuite(func() { + // Set up files for serving via file servers + dir, err := ioutil.TempDir("", "oauth2-proxy-upstream-suite") + Expect(err).ToNot(HaveOccurred()) + Expect(ioutil.WriteFile(path.Join(dir, "foo"), []byte("foo"), 0644)).To(Succeed()) + Expect(ioutil.WriteFile(path.Join(dir, "bar"), []byte("bar"), 0644)).To(Succeed()) + Expect(os.Mkdir(path.Join(dir, "subdir"), os.ModePerm)).To(Succeed()) + Expect(ioutil.WriteFile(path.Join(dir, "subdir", "baz"), []byte("baz"), 0644)).To(Succeed()) + filesDir = dir + + // Set up a webserver that reflects requests + server = httptest.NewServer(&testHTTPUpstream{}) + serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) +}) + +var _ = AfterSuite(func() { + server.Close() + Expect(os.RemoveAll(filesDir)).To(Succeed()) +}) + +const ( + contentType = "Content-Type" + contentLength = "Content-Length" + acceptEncoding = "Accept-Encoding" + applicationJSON = "application/json" + textPlainUTF8 = "text/plain; charset=utf-8" + gapUpstream = "Gap-Upstream-Address" + gapAuth = "Gap-Auth" + gapSignature = "Gap-Signature" +) + +// testHTTPResponse is a struct used for checking responses in table tests +type testHTTPResponse struct { + code int + header http.Header + raw string + request testHTTPRequest +} + +// testHTTPRequest is a struct used to capture the state of a request made to +// an upstream during a test +type testHTTPRequest struct { + Method string + URL string + Header http.Header + Body []byte + Host string + RequestURI string +} + +type testWebSocketResponse struct { + Message string + Origin string +} + +type testHTTPUpstream struct{} + +func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get("Upgrade") == "websocket" { + t.websocketHandler().ServeHTTP(rw, req) + } else { + t.serveHTTP(rw, req) + } +} + +func (t *testHTTPUpstream) serveHTTP(rw http.ResponseWriter, req *http.Request) { + request, err := toTestHTTPRequest(req) + if err != nil { + t.writeError(rw, err) + return + } + + data, err := json.Marshal(request) + if err != nil { + t.writeError(rw, err) + return + } + + rw.Header().Set("Content-Type", "application/json") + rw.Write(data) +} + +func (t *testHTTPUpstream) websocketHandler() http.Handler { + return websocket.Handler(func(ws *websocket.Conn) { + defer ws.Close() + var data []byte + err := websocket.Message.Receive(ws, &data) + if err != nil { + websocket.Message.Send(ws, []byte(err.Error())) + return + } + + wsResponse := testWebSocketResponse{ + Message: string(data), + Origin: ws.Request().Header.Get("Origin"), + } + err = websocket.JSON.Send(ws, wsResponse) + if err != nil { + websocket.Message.Send(ws, []byte(err.Error())) + return + } + }) +} + +func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { + rw.WriteHeader(500) + if err != nil { + rw.Write([]byte(err.Error())) + } +} + +func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { + requestBody := []byte{} + if req.Body != http.NoBody { + var err error + requestBody, err = ioutil.ReadAll(req.Body) + if err != nil { + return testHTTPRequest{}, err + } + } + + return testHTTPRequest{ + Method: req.Method, + URL: req.URL.String(), + Header: req.Header, + Body: requestBody, + Host: req.Host, + RequestURI: req.RequestURI, + }, nil +} + +// String headers added to the response that we do not want to test +func testSanitizeResponseHeader(h http.Header) { + // From HTTP responses + h.Del("Date") + h.Del(contentLength) + + // From File responses + h.Del("Accept-Ranges") + h.Del("Last-Modified") +} + +// Strip the accept header that is added by the HTTP Transport +func testSanitizeRequestHeader(h http.Header) { + h.Del(acceptEncoding) +} diff --git a/pkg/validation/upstreams.go b/pkg/validation/upstreams.go new file mode 100644 index 00000000..be9c0ea7 --- /dev/null +++ b/pkg/validation/upstreams.go @@ -0,0 +1,113 @@ +package validation + +import ( + "fmt" + "net/url" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" +) + +func validateUpstreams(upstreams options.Upstreams) []string { + msgs := []string{} + ids := make(map[string]struct{}) + paths := make(map[string]struct{}) + + for _, upstream := range upstreams { + msgs = append(msgs, validateUpstream(upstream, ids, paths)...) + } + + return msgs +} + +// validateUpstream validates that the upstream has valid options and that +// the ids and paths are unique across all options +func validateUpstream(upstream options.Upstream, ids, paths map[string]struct{}) []string { + msgs := []string{} + + if upstream.ID == "" { + msgs = append(msgs, "upstream has empty id: ids are required for all upstreams") + } + if upstream.Path == "" { + msgs = append(msgs, fmt.Sprintf("upstream %q has empty path: paths are required for all upstreams", upstream.ID)) + } + + // Ensure upstream IDs are unique + if _, ok := ids[upstream.ID]; ok { + msgs = append(msgs, fmt.Sprintf("multiple upstreams found with id %q: upstream ids must be unique", upstream.ID)) + } + ids[upstream.ID] = struct{}{} + + // Ensure upstream Paths are unique + if _, ok := paths[upstream.Path]; ok { + msgs = append(msgs, fmt.Sprintf("multiple upstreams found with path %q: upstream paths must be unique", upstream.Path)) + } + paths[upstream.Path] = struct{}{} + + msgs = append(msgs, validateUpstreamURI(upstream)...) + msgs = append(msgs, validateStaticUpstream(upstream)...) + return msgs +} + +// validateStaticUpstream checks that the StaticCode is only set when Static +// is set, and that any options that do not make sense for a static upstream +// are not set. +func validateStaticUpstream(upstream options.Upstream) []string { + msgs := []string{} + + if !upstream.Static && upstream.StaticCode != nil { + msgs = append(msgs, fmt.Sprintf("upstream %q has staticCode (%d), but is not a static upstream, set 'static' for a static response", upstream.ID, *upstream.StaticCode)) + } + + // Checks after this only make sense when the upstream is static + if !upstream.Static { + return msgs + } + + if upstream.URI != "" { + msgs = append(msgs, fmt.Sprintf("upstream %q has uri, but is a static upstream, this will have no effect.", upstream.ID)) + } + if upstream.InsecureSkipTLSVerify { + msgs = append(msgs, fmt.Sprintf("upstream %q has insecureSkipTLSVerify, but is a static upstream, this will have no effect.", upstream.ID)) + } + if upstream.FlushInterval != nil && *upstream.FlushInterval != time.Second { + msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID)) + } + if !upstream.PassHostHeader { + msgs = append(msgs, fmt.Sprintf("upstream %q has passHostHeader, but is a static upstream, this will have no effect.", upstream.ID)) + } + if !upstream.ProxyWebSockets { + msgs = append(msgs, fmt.Sprintf("upstream %q has proxyWebSockets, but is a static upstream, this will have no effect.", upstream.ID)) + } + + return msgs +} + +func validateUpstreamURI(upstream options.Upstream) []string { + msgs := []string{} + + if !upstream.Static && upstream.URI == "" { + msgs = append(msgs, fmt.Sprintf("upstream %q has empty uri: uris are required for all non-static upstreams", upstream.ID)) + return msgs + } + + // Checks after this only make sense the upstream is not static + if upstream.Static { + return msgs + } + + u, err := url.Parse(upstream.URI) + if err != nil { + msgs = append(msgs, fmt.Sprintf("upstream %q has invalid uri: %v", upstream.ID, err)) + return msgs + } + + switch u.Scheme { + case "http", "https", "file": + // Valid, do nothing + default: + msgs = append(msgs, fmt.Sprintf("upstream %q has invalid scheme: %q", upstream.ID, u.Scheme)) + } + + return msgs +} diff --git a/pkg/validation/upstreams_test.go b/pkg/validation/upstreams_test.go new file mode 100644 index 00000000..4995bf29 --- /dev/null +++ b/pkg/validation/upstreams_test.go @@ -0,0 +1,191 @@ +package validation + +import ( + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Upstreams", func() { + type validateUpstreamTableInput struct { + upstreams options.Upstreams + errStrings []string + } + + flushInterval := 5 * time.Second + staticCode200 := 200 + + validHTTPUpstream := options.Upstream{ + ID: "validHTTPUpstream", + Path: "/validHTTPUpstream", + URI: "http://localhost:8080", + } + validStaticUpstream := options.Upstream{ + ID: "validStaticUpstream", + Path: "/validStaticUpstream", + Static: true, + PassHostHeader: true, // This would normally be defaulted + ProxyWebSockets: true, // this would normally be defaulted + } + validFileUpstream := options.Upstream{ + ID: "validFileUpstream", + Path: "/validFileUpstream", + URI: "file://var/lib/foo", + } + + emptyIDMsg := "upstream has empty id: ids are required for all upstreams" + emptyPathMsg := "upstream \"foo\" has empty path: paths are required for all upstreams" + emptyURIMsg := "upstream \"foo\" has empty uri: uris are required for all non-static upstreams" + invalidURIMsg := "upstream \"foo\" has invalid uri: parse \":\": missing protocol scheme" + invalidURISchemeMsg := "upstream \"foo\" has invalid scheme: \"ftp\"" + staticWithURIMsg := "upstream \"foo\" has uri, but is a static upstream, this will have no effect." + staticWithInsecureMsg := "upstream \"foo\" has insecureSkipTLSVerify, but is a static upstream, this will have no effect." + staticWithFlushIntervalMsg := "upstream \"foo\" has flushInterval, but is a static upstream, this will have no effect." + staticWithPassHostHeaderMsg := "upstream \"foo\" has passHostHeader, but is a static upstream, this will have no effect." + staticWithProxyWebSocketsMsg := "upstream \"foo\" has proxyWebSockets, but is a static upstream, this will have no effect." + multipleIDsMsg := "multiple upstreams found with id \"foo\": upstream ids must be unique" + multiplePathsMsg := "multiple upstreams found with path \"/foo\": upstream paths must be unique" + staticCodeMsg := "upstream \"foo\" has staticCode (200), but is not a static upstream, set 'static' for a static response" + + DescribeTable("validateUpstreams", + func(o *validateUpstreamTableInput) { + Expect(validateUpstreams(o.upstreams)).To(ConsistOf(o.errStrings)) + }, + Entry("with no upstreams", &validateUpstreamTableInput{ + upstreams: options.Upstreams{}, + errStrings: []string{}, + }), + Entry("with valid upstreams", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + validHTTPUpstream, + validStaticUpstream, + validFileUpstream, + }, + errStrings: []string{}, + }), + Entry("with an empty ID", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "", + Path: "/foo", + URI: "http://localhost:8080", + }, + }, + errStrings: []string{emptyIDMsg}, + }), + Entry("with an empty Path", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "", + URI: "http://localhost:8080", + }, + }, + errStrings: []string{emptyPathMsg}, + }), + Entry("with an empty Path", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "", + URI: "http://localhost:8080", + }, + }, + errStrings: []string{emptyPathMsg}, + }), + Entry("with an empty URI", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "/foo", + URI: "", + }, + }, + errStrings: []string{emptyURIMsg}, + }), + Entry("with an invalid URI", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "/foo", + URI: ":", + }, + }, + errStrings: []string{invalidURIMsg}, + }), + Entry("with an invalid URI scheme", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "/foo", + URI: "ftp://foo", + }, + }, + errStrings: []string{invalidURISchemeMsg}, + }), + Entry("with a static upstream and invalid optons", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "/foo", + URI: "ftp://foo", + Static: true, + FlushInterval: &flushInterval, + PassHostHeader: false, + ProxyWebSockets: false, + InsecureSkipTLSVerify: true, + }, + }, + errStrings: []string{ + staticWithURIMsg, + staticWithInsecureMsg, + staticWithFlushIntervalMsg, + staticWithPassHostHeaderMsg, + staticWithProxyWebSocketsMsg, + }, + }), + Entry("with duplicate IDs", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "/foo1", + URI: "http://foo", + }, + { + ID: "foo", + Path: "/foo2", + URI: "http://foo", + }, + }, + errStrings: []string{multipleIDsMsg}, + }), + Entry("with duplicate Paths", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo1", + Path: "/foo", + URI: "http://foo", + }, + { + ID: "foo2", + Path: "/foo", + URI: "http://foo", + }, + }, + errStrings: []string{multiplePathsMsg}, + }), + Entry("when a static code is supplied without static", &validateUpstreamTableInput{ + upstreams: options.Upstreams{ + { + ID: "foo", + Path: "/foo", + StaticCode: &staticCode200, + }, + }, + errStrings: []string{emptyURIMsg, staticCodeMsg}, + }), + ) +}) diff --git a/pkg/validation/validation_suite_test.go b/pkg/validation/validation_suite_test.go new file mode 100644 index 00000000..2c6458fe --- /dev/null +++ b/pkg/validation/validation_suite_test.go @@ -0,0 +1,16 @@ +package validation + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestValidationSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Validation Suite") +}