Merge pull request #1142 from oauth2-proxy/writer-funcs
Add pagewriter to upstream proxy
This commit is contained in:
		
						commit
						06808704a3
					
				|  | @ -8,6 +8,7 @@ | |||
| 
 | ||||
| ## Changes since v7.1.3 | ||||
| 
 | ||||
| - [#1142](https://github.com/oauth2-proxy/oauth2-proxy/pull/1142) Add pagewriter to upstream proxy (@JoelSpeed) | ||||
| - [#1181](https://github.com/oauth2-proxy/oauth2-proxy/pull/1181) Fix incorrect `cfg` name in show-debug-on-error flag (@iTaybb) | ||||
| 
 | ||||
| # V7.1.3 | ||||
|  |  | |||
|  | @ -124,7 +124,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		return nil, fmt.Errorf("error initialising page writer: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter.ProxyErrorHandler) | ||||
| 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error initialising upstream proxy: %v", err) | ||||
| 	} | ||||
|  |  | |||
|  | @ -101,3 +101,73 @@ func NewWriter(opts Opts) (Writer, error) { | |||
| 		staticPageWriter: staticPages, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // WriterFuncs is an implementation of the PageWriter interface based
 | ||||
| // on override functions.
 | ||||
| // If any of the funcs are not provided, a default implementation will be used.
 | ||||
| // This is primarily for us in testing.
 | ||||
| type WriterFuncs struct { | ||||
| 	SignInPageFunc func(rw http.ResponseWriter, req *http.Request, redirectURL string) | ||||
| 	ErrorPageFunc  func(rw http.ResponseWriter, opts ErrorPageOpts) | ||||
| 	ProxyErrorFunc func(rw http.ResponseWriter, req *http.Request, proxyErr error) | ||||
| 	RobotsTxtfunc  func(rw http.ResponseWriter, req *http.Request) | ||||
| } | ||||
| 
 | ||||
| // WriteSignInPage implements the Writer interface.
 | ||||
| // If the SignInPageFunc is provided, this will be used, else a default
 | ||||
| // implementation will be used.
 | ||||
| func (w *WriterFuncs) WriteSignInPage(rw http.ResponseWriter, req *http.Request, redirectURL string) { | ||||
| 	if w.SignInPageFunc != nil { | ||||
| 		w.SignInPageFunc(rw, req, redirectURL) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if _, err := rw.Write([]byte("Sign In")); err != nil { | ||||
| 		rw.WriteHeader(http.StatusInternalServerError) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WriteErrorPage implements the Writer interface.
 | ||||
| // If the ErrorPageFunc is provided, this will be used, else a default
 | ||||
| // implementation will be used.
 | ||||
| func (w *WriterFuncs) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageOpts) { | ||||
| 	if w.ErrorPageFunc != nil { | ||||
| 		w.ErrorPageFunc(rw, opts) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	rw.WriteHeader(opts.Status) | ||||
| 	errMsg := fmt.Sprintf("%d - %v", opts.Status, opts.AppError) | ||||
| 	if _, err := rw.Write([]byte(errMsg)); err != nil { | ||||
| 		rw.WriteHeader(http.StatusInternalServerError) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // ProxyErrorHandler implements the Writer interface.
 | ||||
| // If the ProxyErrorFunc is provided, this will be used, else a default
 | ||||
| // implementation will be used.
 | ||||
| func (w *WriterFuncs) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { | ||||
| 	if w.ProxyErrorFunc != nil { | ||||
| 		w.ProxyErrorFunc(rw, req, proxyErr) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	w.WriteErrorPage(rw, ErrorPageOpts{ | ||||
| 		Status:   http.StatusBadGateway, | ||||
| 		AppError: proxyErr.Error(), | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // WriteRobotsTxt implements the Writer interface.
 | ||||
| // If the RobotsTxtfunc is provided, this will be used, else a default
 | ||||
| // implementation will be used.
 | ||||
| func (w *WriterFuncs) WriteRobotsTxt(rw http.ResponseWriter, req *http.Request) { | ||||
| 	if w.RobotsTxtfunc != nil { | ||||
| 		w.RobotsTxtfunc(rw, req) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if _, err := rw.Write([]byte("Allow: *")); err != nil { | ||||
| 		rw.WriteHeader(http.StatusInternalServerError) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -1,6 +1,8 @@ | |||
| package pagewriter | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
|  | @ -8,6 +10,7 @@ import ( | |||
| 	"path/filepath" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
|  | @ -135,4 +138,144 @@ var _ = Describe("Writer", func() { | |||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("WriterFuncs", func() { | ||||
| 		type writerFuncsTableInput struct { | ||||
| 			writer         Writer | ||||
| 			expectedStatus int | ||||
| 			expectedBody   string | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("WriteSignInPage", | ||||
| 			func(in writerFuncsTableInput) { | ||||
| 				rw := httptest.NewRecorder() | ||||
| 				req := httptest.NewRequest("", "/sign-in", nil) | ||||
| 				redirectURL := "<redirectURL>" | ||||
| 				in.writer.WriteSignInPage(rw, req, redirectURL) | ||||
| 
 | ||||
| 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||
| 			}, | ||||
| 			Entry("With no override", writerFuncsTableInput{ | ||||
| 				writer:         &WriterFuncs{}, | ||||
| 				expectedStatus: 200, | ||||
| 				expectedBody:   "Sign In", | ||||
| 			}), | ||||
| 			Entry("With an override function", writerFuncsTableInput{ | ||||
| 				writer: &WriterFuncs{ | ||||
| 					SignInPageFunc: func(rw http.ResponseWriter, req *http.Request, redirectURL string) { | ||||
| 						rw.WriteHeader(202) | ||||
| 						rw.Write([]byte(fmt.Sprintf("%s %s", req.URL.Path, redirectURL))) | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedStatus: 202, | ||||
| 				expectedBody:   "/sign-in <redirectURL>", | ||||
| 			}), | ||||
| 		) | ||||
| 
 | ||||
| 		DescribeTable("WriteErrorPage", | ||||
| 			func(in writerFuncsTableInput) { | ||||
| 				rw := httptest.NewRecorder() | ||||
| 				in.writer.WriteErrorPage(rw, ErrorPageOpts{ | ||||
| 					Status:      http.StatusInternalServerError, | ||||
| 					RedirectURL: "<redirectURL>", | ||||
| 					RequestID:   "12345", | ||||
| 					AppError:    "application error", | ||||
| 				}) | ||||
| 
 | ||||
| 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||
| 			}, | ||||
| 			Entry("With no override", writerFuncsTableInput{ | ||||
| 				writer:         &WriterFuncs{}, | ||||
| 				expectedStatus: 500, | ||||
| 				expectedBody:   "500 - application error", | ||||
| 			}), | ||||
| 			Entry("With an override function", writerFuncsTableInput{ | ||||
| 				writer: &WriterFuncs{ | ||||
| 					ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) { | ||||
| 						rw.WriteHeader(503) | ||||
| 						rw.Write([]byte(fmt.Sprintf("%s %s", opts.RequestID, opts.RedirectURL))) | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedStatus: 503, | ||||
| 				expectedBody:   "12345 <redirectURL>", | ||||
| 			}), | ||||
| 		) | ||||
| 
 | ||||
| 		DescribeTable("ProxyErrorHandler", | ||||
| 			func(in writerFuncsTableInput) { | ||||
| 				rw := httptest.NewRecorder() | ||||
| 				req := httptest.NewRequest("", "/proxy", nil) | ||||
| 				err := errors.New("proxy error") | ||||
| 				in.writer.ProxyErrorHandler(rw, req, err) | ||||
| 
 | ||||
| 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||
| 			}, | ||||
| 			Entry("With no override", writerFuncsTableInput{ | ||||
| 				writer:         &WriterFuncs{}, | ||||
| 				expectedStatus: 502, | ||||
| 				expectedBody:   "502 - proxy error", | ||||
| 			}), | ||||
| 			Entry("With an override function for the proxy handler", writerFuncsTableInput{ | ||||
| 				writer: &WriterFuncs{ | ||||
| 					ProxyErrorFunc: func(rw http.ResponseWriter, req *http.Request, proxyErr error) { | ||||
| 						rw.WriteHeader(503) | ||||
| 						rw.Write([]byte(fmt.Sprintf("%s %v", req.URL.Path, proxyErr))) | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedStatus: 503, | ||||
| 				expectedBody:   "/proxy proxy error", | ||||
| 			}), | ||||
| 			Entry("With an override function for the error page", writerFuncsTableInput{ | ||||
| 				writer: &WriterFuncs{ | ||||
| 					ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) { | ||||
| 						rw.WriteHeader(500) | ||||
| 						rw.Write([]byte("Internal Server Error")) | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedStatus: 500, | ||||
| 				expectedBody:   "Internal Server Error", | ||||
| 			}), | ||||
| 		) | ||||
| 
 | ||||
| 		DescribeTable("WriteRobotsTxt", | ||||
| 			func(in writerFuncsTableInput) { | ||||
| 				rw := httptest.NewRecorder() | ||||
| 				req := httptest.NewRequest("", "/robots.txt", nil) | ||||
| 				in.writer.WriteRobotsTxt(rw, req) | ||||
| 
 | ||||
| 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||
| 			}, | ||||
| 			Entry("With no override", writerFuncsTableInput{ | ||||
| 				writer:         &WriterFuncs{}, | ||||
| 				expectedStatus: 200, | ||||
| 				expectedBody:   "Allow: *", | ||||
| 			}), | ||||
| 			Entry("With an override function", writerFuncsTableInput{ | ||||
| 				writer: &WriterFuncs{ | ||||
| 					RobotsTxtfunc: func(rw http.ResponseWriter, req *http.Request) { | ||||
| 						rw.WriteHeader(202) | ||||
| 						rw.Write([]byte("Disallow: *")) | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedStatus: 202, | ||||
| 				expectedBody:   "Disallow: *", | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| }) | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
|  | @ -15,7 +16,7 @@ type ProxyErrorHandler func(http.ResponseWriter, *http.Request, error) | |||
| 
 | ||||
| // NewProxy creates a new multiUpstreamProxy that can serve requests directed to
 | ||||
| // multiple upstreams.
 | ||||
| func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, errorHandler ProxyErrorHandler) (http.Handler, error) { | ||||
| func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, writer pagewriter.Writer) (http.Handler, error) { | ||||
| 	m := &multiUpstreamProxy{ | ||||
| 		serveMux: http.NewServeMux(), | ||||
| 	} | ||||
|  | @ -34,7 +35,7 @@ func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, error | |||
| 		case fileScheme: | ||||
| 			m.registerFileServer(upstream, u) | ||||
| 		case httpScheme, httpsScheme: | ||||
| 			m.registerHTTPUpstreamProxy(upstream, u, sigData, errorHandler) | ||||
| 			m.registerHTTPUpstreamProxy(upstream, u, sigData, writer) | ||||
| 		default: | ||||
| 			return nil, fmt.Errorf("unknown scheme for upstream %q: %q", upstream.ID, u.Scheme) | ||||
| 		} | ||||
|  | @ -66,7 +67,7 @@ func (m *multiUpstreamProxy) registerFileServer(upstream options.Upstream, u *ur | |||
| } | ||||
| 
 | ||||
| // registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given.
 | ||||
| func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, errorHandler ProxyErrorHandler) { | ||||
| func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, writer pagewriter.Writer) { | ||||
| 	logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI) | ||||
| 	m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, errorHandler)) | ||||
| 	m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, writer.ProxyErrorHandler)) | ||||
| } | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ import ( | |||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -20,9 +21,11 @@ var _ = Describe("Proxy Suite", func() { | |||
| 	BeforeEach(func() { | ||||
| 		sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} | ||||
| 
 | ||||
| 		errorHandler := func(rw http.ResponseWriter, _ *http.Request, _ error) { | ||||
| 		writer := &pagewriter.WriterFuncs{ | ||||
| 			ProxyErrorFunc: func(rw http.ResponseWriter, _ *http.Request, _ error) { | ||||
| 				rw.WriteHeader(502) | ||||
| 				rw.Write([]byte("Proxy Error")) | ||||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		ok := http.StatusOK | ||||
|  | @ -58,7 +61,7 @@ var _ = Describe("Proxy Suite", func() { | |||
| 		} | ||||
| 
 | ||||
| 		var err error | ||||
| 		upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) | ||||
| 		upstreamServer, err = NewProxy(upstreams, sigData, writer) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 	}) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue