diff --git a/oauthproxy.go b/oauthproxy.go index 50051bff..63e1e48f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -101,8 +101,7 @@ type OAuthProxy struct { sessionChain alice.Chain headersChain alice.Chain preAuthChain alice.Chain - errorPage *app.ErrorPage - signInPage *app.SignInPage + pageWriter app.PageWriter } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided @@ -112,20 +111,31 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("error initialising session store: %v", err) } - templates, err := app.LoadTemplates(opts.Templates.Path) + var basicAuthValidator basic.Validator + if opts.HtpasswdFile != "" { + logger.Printf("using htpasswd file: %s", opts.HtpasswdFile) + var err error + basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile) + if err != nil { + return nil, fmt.Errorf("could not load htpasswdfile: %v", err) + } + } + + pageWriter, err := app.NewPageWriter(app.PageWriterOpts{ + TemplatesPath: opts.Templates.Path, + ProxyPrefix: opts.ProxyPrefix, + Footer: opts.Templates.Footer, + Version: VERSION, + Debug: opts.Templates.Debug, + ProviderName: buildProviderName(opts.GetProvider(), opts.ProviderName), + SignInMessage: buildSignInMessage(opts), + DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, + }) if err != nil { - return nil, fmt.Errorf("error loading templates: %v", err) + return nil, fmt.Errorf("error initialising page writer: %v", err) } - errorPage := &app.ErrorPage{ - Template: templates.Lookup("error.html"), - ProxyPrefix: opts.ProxyPrefix, - Footer: opts.Templates.Footer, - Version: VERSION, - Debug: opts.Templates.Debug, - } - - upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), errorPage.ProxyErrorHandler) + upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter.ProxyErrorHandler) if err != nil { return nil, fmt.Errorf("error initialising upstream proxy: %v", err) } @@ -158,27 +168,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr } } - var basicAuthValidator basic.Validator - if opts.HtpasswdFile != "" { - logger.Printf("using htpasswd file: %s", opts.HtpasswdFile) - var err error - basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile) - if err != nil { - return nil, fmt.Errorf("could not load htpasswdfile: %v", err) - } - } - - signInPage := &app.SignInPage{ - Template: templates.Lookup("sign_in.html"), - ErrorPage: errorPage, - ProxyPrefix: opts.ProxyPrefix, - ProviderName: buildProviderName(opts.GetProvider(), opts.ProviderName), - SignInMessage: buildSignInMessage(opts), - Footer: opts.Templates.Footer, - Version: VERSION, - DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, - } - allowedRoutes, err := buildRoutesAllowlist(opts) if err != nil { return nil, err @@ -232,8 +221,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr sessionChain: sessionChain, headersChain: headersChain, preAuthChain: preAuthChain, - errorPage: errorPage, - signInPage: signInPage, + pageWriter: pageWriter, }, nil } @@ -540,7 +528,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i redirectURL = "/" } - p.errorPage.Render(rw, code, redirectURL, appError, messages...) + p.pageWriter.WriteErrorPage(rw, code, redirectURL, appError, messages...) } // IsAllowedRequest is used to check if auth should be skipped for this request @@ -601,7 +589,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code redirectURL = "/" } - p.signInPage.Render(rw, redirectURL) + p.pageWriter.WriteSignInPage(rw, redirectURL) } // ManualSignIn handles basic auth logins to the proxy diff --git a/pkg/app/error_page.go b/pkg/app/error_page.go index 56d1c6af..62663838 100644 --- a/pkg/app/error_page.go +++ b/pkg/app/error_page.go @@ -17,30 +17,30 @@ var errorMessages = map[int]string{ http.StatusUnauthorized: "You need to be logged in to access this resource.", } -// ErrorPage is used to render error pages. -type ErrorPage struct { - // Template is the error page HTML template. - Template *template.Template +// errorPageWriter is used to render error pages. +type errorPageWriter struct { + // template is the error page HTML template. + template *template.Template - // ProxyPrefix is the prefix under which OAuth2 Proxy pages are served. - ProxyPrefix string + // proxyPrefix is the prefix under which OAuth2 Proxy pages are served. + proxyPrefix string - // Footer is the footer to be displayed at the bottom of the page. + // footer is the footer to be displayed at the bottom of the page. // If not set, a default footer will be used. - Footer string + footer string - // Version is the OAuth2 Proxy version to be used in the default footer. - Version string + // version is the OAuth2 Proxy version to be used in the default footer. + version string - // Debug determines whether errors pages should be rendered with detailed + // debug determines whether errors pages should be rendered with detailed // errors. - Debug bool + debug bool } -// Render writes an error page to the given response writer. +// WriteErrorPage writes an error page to the given response writer. // It uses the passed redirectURL to give users the option to go back to where // they originally came from or try signing in again. -func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) { +func (e *errorPageWriter) WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) { rw.WriteHeader(status) // We allow unescaped template.HTML since it is user configured options @@ -56,14 +56,14 @@ func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL strin }{ Title: http.StatusText(status), Message: e.getMessage(status, appError, messages...), - ProxyPrefix: e.ProxyPrefix, + ProxyPrefix: e.proxyPrefix, StatusCode: status, Redirect: redirectURL, - Footer: template.HTML(e.Footer), - Version: e.Version, + Footer: template.HTML(e.footer), + Version: e.version, } - if err := e.Template.Execute(rw, data); err != nil { + if err := e.template.Execute(rw, data); err != nil { logger.Printf("Error rendering error template: %v", err) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } @@ -72,18 +72,18 @@ func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL strin // ProxyErrorHandler is used by the upstream ReverseProxy to render error pages // when there are issues with upstream servers. // It is expected to always render a bad gateway error. -func (e *ErrorPage) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { +func (e *errorPageWriter) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { logger.Errorf("Error proxying to upstream server: %v", proxyErr) - e.Render(rw, http.StatusBadGateway, "", proxyErr.Error(), "There was a problem connecting to the upstream server.") + e.WriteErrorPage(rw, http.StatusBadGateway, "", proxyErr.Error(), "There was a problem connecting to the upstream server.") } // getMessage creates the message for the template parameters. -// If the ErrorPage.Debug is enabled, the application error takes precedence. +// If the errorPagewriter.Debug is enabled, the application error takes precedence. // Otherwise, any messages will be used. // The first message is expected to be a format string. // If no messages are supplied, a default error message will be used. -func (e *ErrorPage) getMessage(status int, appError string, messages ...interface{}) string { - if e.Debug { +func (e *errorPageWriter) getMessage(status int, appError string, messages ...interface{}) string { + if e.debug { return appError } if len(messages) > 0 { diff --git a/pkg/app/error_page_test.go b/pkg/app/error_page_test.go index 5c4f78fa..28880d68 100644 --- a/pkg/app/error_page_test.go +++ b/pkg/app/error_page_test.go @@ -10,25 +10,25 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Error Page", func() { - var errorPage *ErrorPage +var _ = Describe("Error Page Writer", func() { + var errorPage *errorPageWriter BeforeEach(func() { tmpl, err := template.New("").Parse("{{.Title}} {{.Message}} {{.ProxyPrefix}} {{.StatusCode}} {{.Redirect}} {{.Footer}} {{.Version}}") Expect(err).ToNot(HaveOccurred()) - errorPage = &ErrorPage{ - Template: tmpl, - ProxyPrefix: "/prefix/", - Footer: "Custom Footer Text", - Version: "v0.0.0-test", + errorPage = &errorPageWriter{ + template: tmpl, + proxyPrefix: "/prefix/", + footer: "Custom Footer Text", + version: "v0.0.0-test", } }) - Context("Render", func() { + Context("WriteErrorPage", func() { It("Writes the template to the response writer", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 403, "/redirect", "Access Denied") + errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -37,7 +37,7 @@ var _ = Describe("Error Page", func() { It("With a different code, uses the stock message for the correct code", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 500, "/redirect", "Access Denied") + errorPage.WriteErrorPage(recorder, 500, "/redirect", "Access Denied") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -46,7 +46,7 @@ var _ = Describe("Error Page", func() { It("With a message override, uses the message", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 403, "/redirect", "Access Denied", "An extra message: %s", "with more context.") + errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied", "An extra message: %s", "with more context.") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -71,14 +71,14 @@ var _ = Describe("Error Page", func() { tmpl, err := template.New("").Parse("{{.Message}}") Expect(err).ToNot(HaveOccurred()) - errorPage.Template = tmpl - errorPage.Debug = true + errorPage.template = tmpl + errorPage.debug = true }) - Context("Render", func() { + Context("WriteErrorPage", func() { It("Writes the detailed error in place of the message", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 403, "/redirect", "Debug error") + errorPage.WriteErrorPage(recorder, 403, "/redirect", "Debug error") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/app/pagewriter.go b/pkg/app/pagewriter.go new file mode 100644 index 00000000..84512789 --- /dev/null +++ b/pkg/app/pagewriter.go @@ -0,0 +1,85 @@ +package app + +import ( + "fmt" + "net/http" +) + +// PageWriter is an interface for rendering html templates for both sign-in and +// error pages. +// It can also be used to write errors for the http.ReverseProxy used in the +// upstream package. +type PageWriter interface { + WriteSignInPage(rw http.ResponseWriter, redirectURL string) + WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) + ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) +} + +// pageWriter implements the PageWriter interface +type pageWriter struct { + *errorPageWriter + *signInPageWriter +} + +// PageWriterOpts contains all options required to configure the template +// rendering within OAuth2 Proxy. +type PageWriterOpts struct { + // TemplatesPath is the path from which to load custom templates for the sign-in and error pages. + TemplatesPath string + + // ProxyPrefix is the prefix under which OAuth2 Proxy pages are served. + ProxyPrefix string + + // Footer is the footer to be displayed at the bottom of the page. + // If not set, a default footer will be used. + Footer string + + // Version is the OAuth2 Proxy version to be used in the default footer. + Version string + + // Debug determines whether errors pages should be rendered with detailed + // errors. + Debug bool + + // DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page. + DisplayLoginForm bool + + // ProviderName is the name of the provider that should be displayed on the login button. + ProviderName string + + // SignInMessage is the messge displayed above the login button. + SignInMessage string +} + +// NewPageWriter constructs a PageWriter from the options given to allow +// rendering of sign-in and error pages. +func NewPageWriter(opts PageWriterOpts) (PageWriter, error) { + templates, err := loadTemplates(opts.TemplatesPath) + if err != nil { + return nil, fmt.Errorf("error loading templates: %v", err) + } + + errorPage := &errorPageWriter{ + template: templates.Lookup("error.html"), + proxyPrefix: opts.ProxyPrefix, + footer: opts.Footer, + version: opts.Version, + debug: opts.Debug, + } + + signInPage := &signInPageWriter{ + template: templates.Lookup("sign_in.html"), + errorPageWriter: errorPage, + proxyPrefix: opts.ProxyPrefix, + providerName: opts.ProviderName, + signInMessage: opts.SignInMessage, + footer: opts.Footer, + version: opts.Version, + displayLoginForm: opts.DisplayLoginForm, + } + + return &pageWriter{ + errorPageWriter: errorPage, + signInPageWriter: signInPage, + }, nil +} diff --git a/pkg/app/pagewriter_test.go b/pkg/app/pagewriter_test.go new file mode 100644 index 00000000..7f7dee1f --- /dev/null +++ b/pkg/app/pagewriter_test.go @@ -0,0 +1,126 @@ +package app + +import ( + "io/ioutil" + "net/http/httptest" + "os" + "path/filepath" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PageWriter", func() { + Context("NewPageWriter", func() { + var writer PageWriter + var opts PageWriterOpts + + BeforeEach(func() { + opts = PageWriterOpts{ + TemplatesPath: "", + ProxyPrefix: "/prefix", + Footer: "