Merge pull request #1142 from oauth2-proxy/writer-funcs
Add pagewriter to upstream proxy
This commit is contained in:
		
						commit
						06808704a3
					
				|  | @ -8,6 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v7.1.3 | ## Changes since v7.1.3 | ||||||
| 
 | 
 | ||||||
|  | - [#1142](https://github.com/oauth2-proxy/oauth2-proxy/pull/1142) Add pagewriter to upstream proxy (@JoelSpeed) | ||||||
| - [#1181](https://github.com/oauth2-proxy/oauth2-proxy/pull/1181) Fix incorrect `cfg` name in show-debug-on-error flag (@iTaybb) | - [#1181](https://github.com/oauth2-proxy/oauth2-proxy/pull/1181) Fix incorrect `cfg` name in show-debug-on-error flag (@iTaybb) | ||||||
| 
 | 
 | ||||||
| # V7.1.3 | # V7.1.3 | ||||||
|  |  | ||||||
|  | @ -124,7 +124,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		return nil, fmt.Errorf("error initialising page writer: %v", err) | 		return nil, fmt.Errorf("error initialising page writer: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter.ProxyErrorHandler) | 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("error initialising upstream proxy: %v", err) | 		return nil, fmt.Errorf("error initialising upstream proxy: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -101,3 +101,73 @@ func NewWriter(opts Opts) (Writer, error) { | ||||||
| 		staticPageWriter: staticPages, | 		staticPageWriter: staticPages, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // WriterFuncs is an implementation of the PageWriter interface based
 | ||||||
|  | // on override functions.
 | ||||||
|  | // If any of the funcs are not provided, a default implementation will be used.
 | ||||||
|  | // This is primarily for us in testing.
 | ||||||
|  | type WriterFuncs struct { | ||||||
|  | 	SignInPageFunc func(rw http.ResponseWriter, req *http.Request, redirectURL string) | ||||||
|  | 	ErrorPageFunc  func(rw http.ResponseWriter, opts ErrorPageOpts) | ||||||
|  | 	ProxyErrorFunc func(rw http.ResponseWriter, req *http.Request, proxyErr error) | ||||||
|  | 	RobotsTxtfunc  func(rw http.ResponseWriter, req *http.Request) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WriteSignInPage implements the Writer interface.
 | ||||||
|  | // If the SignInPageFunc is provided, this will be used, else a default
 | ||||||
|  | // implementation will be used.
 | ||||||
|  | func (w *WriterFuncs) WriteSignInPage(rw http.ResponseWriter, req *http.Request, redirectURL string) { | ||||||
|  | 	if w.SignInPageFunc != nil { | ||||||
|  | 		w.SignInPageFunc(rw, req, redirectURL) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, err := rw.Write([]byte("Sign In")); err != nil { | ||||||
|  | 		rw.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WriteErrorPage implements the Writer interface.
 | ||||||
|  | // If the ErrorPageFunc is provided, this will be used, else a default
 | ||||||
|  | // implementation will be used.
 | ||||||
|  | func (w *WriterFuncs) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageOpts) { | ||||||
|  | 	if w.ErrorPageFunc != nil { | ||||||
|  | 		w.ErrorPageFunc(rw, opts) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rw.WriteHeader(opts.Status) | ||||||
|  | 	errMsg := fmt.Sprintf("%d - %v", opts.Status, opts.AppError) | ||||||
|  | 	if _, err := rw.Write([]byte(errMsg)); err != nil { | ||||||
|  | 		rw.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ProxyErrorHandler implements the Writer interface.
 | ||||||
|  | // If the ProxyErrorFunc is provided, this will be used, else a default
 | ||||||
|  | // implementation will be used.
 | ||||||
|  | func (w *WriterFuncs) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { | ||||||
|  | 	if w.ProxyErrorFunc != nil { | ||||||
|  | 		w.ProxyErrorFunc(rw, req, proxyErr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	w.WriteErrorPage(rw, ErrorPageOpts{ | ||||||
|  | 		Status:   http.StatusBadGateway, | ||||||
|  | 		AppError: proxyErr.Error(), | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WriteRobotsTxt implements the Writer interface.
 | ||||||
|  | // If the RobotsTxtfunc is provided, this will be used, else a default
 | ||||||
|  | // implementation will be used.
 | ||||||
|  | func (w *WriterFuncs) WriteRobotsTxt(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 	if w.RobotsTxtfunc != nil { | ||||||
|  | 		w.RobotsTxtfunc(rw, req) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, err := rw.Write([]byte("Allow: *")); err != nil { | ||||||
|  | 		rw.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -1,6 +1,8 @@ | ||||||
| package pagewriter | package pagewriter | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | @ -8,6 +10,7 @@ import ( | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 
 | 
 | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -135,4 +138,144 @@ var _ = Describe("Writer", func() { | ||||||
| 			}) | 			}) | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("WriterFuncs", func() { | ||||||
|  | 		type writerFuncsTableInput struct { | ||||||
|  | 			writer         Writer | ||||||
|  | 			expectedStatus int | ||||||
|  | 			expectedBody   string | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("WriteSignInPage", | ||||||
|  | 			func(in writerFuncsTableInput) { | ||||||
|  | 				rw := httptest.NewRecorder() | ||||||
|  | 				req := httptest.NewRequest("", "/sign-in", nil) | ||||||
|  | 				redirectURL := "<redirectURL>" | ||||||
|  | 				in.writer.WriteSignInPage(rw, req, redirectURL) | ||||||
|  | 
 | ||||||
|  | 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("With no override", writerFuncsTableInput{ | ||||||
|  | 				writer:         &WriterFuncs{}, | ||||||
|  | 				expectedStatus: 200, | ||||||
|  | 				expectedBody:   "Sign In", | ||||||
|  | 			}), | ||||||
|  | 			Entry("With an override function", writerFuncsTableInput{ | ||||||
|  | 				writer: &WriterFuncs{ | ||||||
|  | 					SignInPageFunc: func(rw http.ResponseWriter, req *http.Request, redirectURL string) { | ||||||
|  | 						rw.WriteHeader(202) | ||||||
|  | 						rw.Write([]byte(fmt.Sprintf("%s %s", req.URL.Path, redirectURL))) | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedStatus: 202, | ||||||
|  | 				expectedBody:   "/sign-in <redirectURL>", | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("WriteErrorPage", | ||||||
|  | 			func(in writerFuncsTableInput) { | ||||||
|  | 				rw := httptest.NewRecorder() | ||||||
|  | 				in.writer.WriteErrorPage(rw, ErrorPageOpts{ | ||||||
|  | 					Status:      http.StatusInternalServerError, | ||||||
|  | 					RedirectURL: "<redirectURL>", | ||||||
|  | 					RequestID:   "12345", | ||||||
|  | 					AppError:    "application error", | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("With no override", writerFuncsTableInput{ | ||||||
|  | 				writer:         &WriterFuncs{}, | ||||||
|  | 				expectedStatus: 500, | ||||||
|  | 				expectedBody:   "500 - application error", | ||||||
|  | 			}), | ||||||
|  | 			Entry("With an override function", writerFuncsTableInput{ | ||||||
|  | 				writer: &WriterFuncs{ | ||||||
|  | 					ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) { | ||||||
|  | 						rw.WriteHeader(503) | ||||||
|  | 						rw.Write([]byte(fmt.Sprintf("%s %s", opts.RequestID, opts.RedirectURL))) | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedStatus: 503, | ||||||
|  | 				expectedBody:   "12345 <redirectURL>", | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("ProxyErrorHandler", | ||||||
|  | 			func(in writerFuncsTableInput) { | ||||||
|  | 				rw := httptest.NewRecorder() | ||||||
|  | 				req := httptest.NewRequest("", "/proxy", nil) | ||||||
|  | 				err := errors.New("proxy error") | ||||||
|  | 				in.writer.ProxyErrorHandler(rw, req, err) | ||||||
|  | 
 | ||||||
|  | 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("With no override", writerFuncsTableInput{ | ||||||
|  | 				writer:         &WriterFuncs{}, | ||||||
|  | 				expectedStatus: 502, | ||||||
|  | 				expectedBody:   "502 - proxy error", | ||||||
|  | 			}), | ||||||
|  | 			Entry("With an override function for the proxy handler", writerFuncsTableInput{ | ||||||
|  | 				writer: &WriterFuncs{ | ||||||
|  | 					ProxyErrorFunc: func(rw http.ResponseWriter, req *http.Request, proxyErr error) { | ||||||
|  | 						rw.WriteHeader(503) | ||||||
|  | 						rw.Write([]byte(fmt.Sprintf("%s %v", req.URL.Path, proxyErr))) | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedStatus: 503, | ||||||
|  | 				expectedBody:   "/proxy proxy error", | ||||||
|  | 			}), | ||||||
|  | 			Entry("With an override function for the error page", writerFuncsTableInput{ | ||||||
|  | 				writer: &WriterFuncs{ | ||||||
|  | 					ErrorPageFunc: func(rw http.ResponseWriter, opts ErrorPageOpts) { | ||||||
|  | 						rw.WriteHeader(500) | ||||||
|  | 						rw.Write([]byte("Internal Server Error")) | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedStatus: 500, | ||||||
|  | 				expectedBody:   "Internal Server Error", | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("WriteRobotsTxt", | ||||||
|  | 			func(in writerFuncsTableInput) { | ||||||
|  | 				rw := httptest.NewRecorder() | ||||||
|  | 				req := httptest.NewRequest("", "/robots.txt", nil) | ||||||
|  | 				in.writer.WriteRobotsTxt(rw, req) | ||||||
|  | 
 | ||||||
|  | 				Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(rw.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(in.expectedBody)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("With no override", writerFuncsTableInput{ | ||||||
|  | 				writer:         &WriterFuncs{}, | ||||||
|  | 				expectedStatus: 200, | ||||||
|  | 				expectedBody:   "Allow: *", | ||||||
|  | 			}), | ||||||
|  | 			Entry("With an override function", writerFuncsTableInput{ | ||||||
|  | 				writer: &WriterFuncs{ | ||||||
|  | 					RobotsTxtfunc: func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 						rw.WriteHeader(202) | ||||||
|  | 						rw.Write([]byte("Disallow: *")) | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedStatus: 202, | ||||||
|  | 				expectedBody:   "Disallow: *", | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -15,7 +16,7 @@ type ProxyErrorHandler func(http.ResponseWriter, *http.Request, error) | ||||||
| 
 | 
 | ||||||
| // NewProxy creates a new multiUpstreamProxy that can serve requests directed to
 | // NewProxy creates a new multiUpstreamProxy that can serve requests directed to
 | ||||||
| // multiple upstreams.
 | // multiple upstreams.
 | ||||||
| func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, errorHandler ProxyErrorHandler) (http.Handler, error) { | func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, writer pagewriter.Writer) (http.Handler, error) { | ||||||
| 	m := &multiUpstreamProxy{ | 	m := &multiUpstreamProxy{ | ||||||
| 		serveMux: http.NewServeMux(), | 		serveMux: http.NewServeMux(), | ||||||
| 	} | 	} | ||||||
|  | @ -34,7 +35,7 @@ func NewProxy(upstreams options.Upstreams, sigData *options.SignatureData, error | ||||||
| 		case fileScheme: | 		case fileScheme: | ||||||
| 			m.registerFileServer(upstream, u) | 			m.registerFileServer(upstream, u) | ||||||
| 		case httpScheme, httpsScheme: | 		case httpScheme, httpsScheme: | ||||||
| 			m.registerHTTPUpstreamProxy(upstream, u, sigData, errorHandler) | 			m.registerHTTPUpstreamProxy(upstream, u, sigData, writer) | ||||||
| 		default: | 		default: | ||||||
| 			return nil, fmt.Errorf("unknown scheme for upstream %q: %q", upstream.ID, u.Scheme) | 			return nil, fmt.Errorf("unknown scheme for upstream %q: %q", upstream.ID, u.Scheme) | ||||||
| 		} | 		} | ||||||
|  | @ -66,7 +67,7 @@ func (m *multiUpstreamProxy) registerFileServer(upstream options.Upstream, u *ur | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given.
 | // registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given.
 | ||||||
| func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, errorHandler ProxyErrorHandler) { | func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, writer pagewriter.Writer) { | ||||||
| 	logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI) | 	logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI) | ||||||
| 	m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, errorHandler)) | 	m.serveMux.Handle(upstream.Path, newHTTPUpstreamProxy(upstream, u, sigData, writer.ProxyErrorHandler)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -9,6 +9,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -20,9 +21,11 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 	BeforeEach(func() { | 	BeforeEach(func() { | ||||||
| 		sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} | 		sigData := &options.SignatureData{Hash: crypto.SHA256, Key: "secret"} | ||||||
| 
 | 
 | ||||||
| 		errorHandler := func(rw http.ResponseWriter, _ *http.Request, _ error) { | 		writer := &pagewriter.WriterFuncs{ | ||||||
| 			rw.WriteHeader(502) | 			ProxyErrorFunc: func(rw http.ResponseWriter, _ *http.Request, _ error) { | ||||||
| 			rw.Write([]byte("Proxy Error")) | 				rw.WriteHeader(502) | ||||||
|  | 				rw.Write([]byte("Proxy Error")) | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		ok := http.StatusOK | 		ok := http.StatusOK | ||||||
|  | @ -58,7 +61,7 @@ var _ = Describe("Proxy Suite", func() { | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		var err error | 		var err error | ||||||
| 		upstreamServer, err = NewProxy(upstreams, sigData, errorHandler) | 		upstreamServer, err = NewProxy(upstreams, sigData, writer) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue