Wrap templates and page rendering in PageWriter interface
This commit is contained in:
		
							parent
							
								
									dba6989054
								
							
						
					
					
						commit
						e8e2af73df
					
				| 
						 | 
					@ -101,8 +101,7 @@ type OAuthProxy struct {
 | 
				
			||||||
	sessionChain alice.Chain
 | 
						sessionChain alice.Chain
 | 
				
			||||||
	headersChain alice.Chain
 | 
						headersChain alice.Chain
 | 
				
			||||||
	preAuthChain alice.Chain
 | 
						preAuthChain alice.Chain
 | 
				
			||||||
	errorPage    *app.ErrorPage
 | 
						pageWriter   app.PageWriter
 | 
				
			||||||
	signInPage   *app.SignInPage
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | 
					// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | 
				
			||||||
| 
						 | 
					@ -112,20 +111,31 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
 | 
				
			||||||
		return nil, fmt.Errorf("error initialising session store: %v", err)
 | 
							return nil, fmt.Errorf("error initialising session store: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	templates, err := app.LoadTemplates(opts.Templates.Path)
 | 
						var basicAuthValidator basic.Validator
 | 
				
			||||||
 | 
						if opts.HtpasswdFile != "" {
 | 
				
			||||||
 | 
							logger.Printf("using htpasswd file: %s", opts.HtpasswdFile)
 | 
				
			||||||
 | 
							var err error
 | 
				
			||||||
 | 
							basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("error loading templates: %v", err)
 | 
								return nil, fmt.Errorf("could not load htpasswdfile: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	errorPage := &app.ErrorPage{
 | 
						pageWriter, err := app.NewPageWriter(app.PageWriterOpts{
 | 
				
			||||||
		Template:    templates.Lookup("error.html"),
 | 
							TemplatesPath:    opts.Templates.Path,
 | 
				
			||||||
		ProxyPrefix:      opts.ProxyPrefix,
 | 
							ProxyPrefix:      opts.ProxyPrefix,
 | 
				
			||||||
		Footer:           opts.Templates.Footer,
 | 
							Footer:           opts.Templates.Footer,
 | 
				
			||||||
		Version:          VERSION,
 | 
							Version:          VERSION,
 | 
				
			||||||
		Debug:            opts.Templates.Debug,
 | 
							Debug:            opts.Templates.Debug,
 | 
				
			||||||
 | 
							ProviderName:     buildProviderName(opts.GetProvider(), opts.ProviderName),
 | 
				
			||||||
 | 
							SignInMessage:    buildSignInMessage(opts),
 | 
				
			||||||
 | 
							DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("error initialising page writer: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), errorPage.ProxyErrorHandler)
 | 
						upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter.ProxyErrorHandler)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("error initialising upstream proxy: %v", err)
 | 
							return nil, fmt.Errorf("error initialising upstream proxy: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -158,27 +168,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var basicAuthValidator basic.Validator
 | 
					 | 
				
			||||||
	if opts.HtpasswdFile != "" {
 | 
					 | 
				
			||||||
		logger.Printf("using htpasswd file: %s", opts.HtpasswdFile)
 | 
					 | 
				
			||||||
		var err error
 | 
					 | 
				
			||||||
		basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, fmt.Errorf("could not load htpasswdfile: %v", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	signInPage := &app.SignInPage{
 | 
					 | 
				
			||||||
		Template:         templates.Lookup("sign_in.html"),
 | 
					 | 
				
			||||||
		ErrorPage:        errorPage,
 | 
					 | 
				
			||||||
		ProxyPrefix:      opts.ProxyPrefix,
 | 
					 | 
				
			||||||
		ProviderName:     buildProviderName(opts.GetProvider(), opts.ProviderName),
 | 
					 | 
				
			||||||
		SignInMessage:    buildSignInMessage(opts),
 | 
					 | 
				
			||||||
		Footer:           opts.Templates.Footer,
 | 
					 | 
				
			||||||
		Version:          VERSION,
 | 
					 | 
				
			||||||
		DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	allowedRoutes, err := buildRoutesAllowlist(opts)
 | 
						allowedRoutes, err := buildRoutesAllowlist(opts)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
| 
						 | 
					@ -232,8 +221,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
 | 
				
			||||||
		sessionChain:       sessionChain,
 | 
							sessionChain:       sessionChain,
 | 
				
			||||||
		headersChain:       headersChain,
 | 
							headersChain:       headersChain,
 | 
				
			||||||
		preAuthChain:       preAuthChain,
 | 
							preAuthChain:       preAuthChain,
 | 
				
			||||||
		errorPage:          errorPage,
 | 
							pageWriter:         pageWriter,
 | 
				
			||||||
		signInPage:         signInPage,
 | 
					 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -540,7 +528,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i
 | 
				
			||||||
		redirectURL = "/"
 | 
							redirectURL = "/"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p.errorPage.Render(rw, code, redirectURL, appError, messages...)
 | 
						p.pageWriter.WriteErrorPage(rw, code, redirectURL, appError, messages...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// IsAllowedRequest is used to check if auth should be skipped for this request
 | 
					// IsAllowedRequest is used to check if auth should be skipped for this request
 | 
				
			||||||
| 
						 | 
					@ -601,7 +589,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
 | 
				
			||||||
		redirectURL = "/"
 | 
							redirectURL = "/"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p.signInPage.Render(rw, redirectURL)
 | 
						p.pageWriter.WriteSignInPage(rw, redirectURL)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ManualSignIn handles basic auth logins to the proxy
 | 
					// ManualSignIn handles basic auth logins to the proxy
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,30 +17,30 @@ var errorMessages = map[int]string{
 | 
				
			||||||
	http.StatusUnauthorized:        "You need to be logged in to access this resource.",
 | 
						http.StatusUnauthorized:        "You need to be logged in to access this resource.",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ErrorPage is used to render error pages.
 | 
					// errorPageWriter is used to render error pages.
 | 
				
			||||||
type ErrorPage struct {
 | 
					type errorPageWriter struct {
 | 
				
			||||||
	// Template is the error page HTML template.
 | 
						// template is the error page HTML template.
 | 
				
			||||||
	Template *template.Template
 | 
						template *template.Template
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ProxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | 
						// proxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | 
				
			||||||
	ProxyPrefix string
 | 
						proxyPrefix string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Footer is the footer to be displayed at the bottom of the page.
 | 
						// footer is the footer to be displayed at the bottom of the page.
 | 
				
			||||||
	// If not set, a default footer will be used.
 | 
						// If not set, a default footer will be used.
 | 
				
			||||||
	Footer string
 | 
						footer string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Version is the OAuth2 Proxy version to be used in the default footer.
 | 
						// version is the OAuth2 Proxy version to be used in the default footer.
 | 
				
			||||||
	Version string
 | 
						version string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Debug determines whether errors pages should be rendered with detailed
 | 
						// debug determines whether errors pages should be rendered with detailed
 | 
				
			||||||
	// errors.
 | 
						// errors.
 | 
				
			||||||
	Debug bool
 | 
						debug bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Render writes an error page to the given response writer.
 | 
					// WriteErrorPage writes an error page to the given response writer.
 | 
				
			||||||
// It uses the passed redirectURL to give users the option to go back to where
 | 
					// It uses the passed redirectURL to give users the option to go back to where
 | 
				
			||||||
// they originally came from or try signing in again.
 | 
					// they originally came from or try signing in again.
 | 
				
			||||||
func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) {
 | 
					func (e *errorPageWriter) WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) {
 | 
				
			||||||
	rw.WriteHeader(status)
 | 
						rw.WriteHeader(status)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// We allow unescaped template.HTML since it is user configured options
 | 
						// We allow unescaped template.HTML since it is user configured options
 | 
				
			||||||
| 
						 | 
					@ -56,14 +56,14 @@ func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL strin
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		Title:       http.StatusText(status),
 | 
							Title:       http.StatusText(status),
 | 
				
			||||||
		Message:     e.getMessage(status, appError, messages...),
 | 
							Message:     e.getMessage(status, appError, messages...),
 | 
				
			||||||
		ProxyPrefix: e.ProxyPrefix,
 | 
							ProxyPrefix: e.proxyPrefix,
 | 
				
			||||||
		StatusCode:  status,
 | 
							StatusCode:  status,
 | 
				
			||||||
		Redirect:    redirectURL,
 | 
							Redirect:    redirectURL,
 | 
				
			||||||
		Footer:      template.HTML(e.Footer),
 | 
							Footer:      template.HTML(e.footer),
 | 
				
			||||||
		Version:     e.Version,
 | 
							Version:     e.version,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := e.Template.Execute(rw, data); err != nil {
 | 
						if err := e.template.Execute(rw, data); err != nil {
 | 
				
			||||||
		logger.Printf("Error rendering error template: %v", err)
 | 
							logger.Printf("Error rendering error template: %v", err)
 | 
				
			||||||
		http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 | 
							http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -72,18 +72,18 @@ func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL strin
 | 
				
			||||||
// ProxyErrorHandler is used by the upstream ReverseProxy to render error pages
 | 
					// ProxyErrorHandler is used by the upstream ReverseProxy to render error pages
 | 
				
			||||||
// when there are issues with upstream servers.
 | 
					// when there are issues with upstream servers.
 | 
				
			||||||
// It is expected to always render a bad gateway error.
 | 
					// It is expected to always render a bad gateway error.
 | 
				
			||||||
func (e *ErrorPage) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) {
 | 
					func (e *errorPageWriter) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) {
 | 
				
			||||||
	logger.Errorf("Error proxying to upstream server: %v", proxyErr)
 | 
						logger.Errorf("Error proxying to upstream server: %v", proxyErr)
 | 
				
			||||||
	e.Render(rw, http.StatusBadGateway, "", proxyErr.Error(), "There was a problem connecting to the upstream server.")
 | 
						e.WriteErrorPage(rw, http.StatusBadGateway, "", proxyErr.Error(), "There was a problem connecting to the upstream server.")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// getMessage creates the message for the template parameters.
 | 
					// getMessage creates the message for the template parameters.
 | 
				
			||||||
// If the ErrorPage.Debug is enabled, the application error takes precedence.
 | 
					// If the errorPagewriter.Debug is enabled, the application error takes precedence.
 | 
				
			||||||
// Otherwise, any messages will be used.
 | 
					// Otherwise, any messages will be used.
 | 
				
			||||||
// The first message is expected to be a format string.
 | 
					// The first message is expected to be a format string.
 | 
				
			||||||
// If no messages are supplied, a default error message will be used.
 | 
					// If no messages are supplied, a default error message will be used.
 | 
				
			||||||
func (e *ErrorPage) getMessage(status int, appError string, messages ...interface{}) string {
 | 
					func (e *errorPageWriter) getMessage(status int, appError string, messages ...interface{}) string {
 | 
				
			||||||
	if e.Debug {
 | 
						if e.debug {
 | 
				
			||||||
		return appError
 | 
							return appError
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(messages) > 0 {
 | 
						if len(messages) > 0 {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,25 +10,25 @@ import (
 | 
				
			||||||
	. "github.com/onsi/gomega"
 | 
						. "github.com/onsi/gomega"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var _ = Describe("Error Page", func() {
 | 
					var _ = Describe("Error Page Writer", func() {
 | 
				
			||||||
	var errorPage *ErrorPage
 | 
						var errorPage *errorPageWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	BeforeEach(func() {
 | 
						BeforeEach(func() {
 | 
				
			||||||
		tmpl, err := template.New("").Parse("{{.Title}} {{.Message}} {{.ProxyPrefix}} {{.StatusCode}} {{.Redirect}} {{.Footer}} {{.Version}}")
 | 
							tmpl, err := template.New("").Parse("{{.Title}} {{.Message}} {{.ProxyPrefix}} {{.StatusCode}} {{.Redirect}} {{.Footer}} {{.Version}}")
 | 
				
			||||||
		Expect(err).ToNot(HaveOccurred())
 | 
							Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		errorPage = &ErrorPage{
 | 
							errorPage = &errorPageWriter{
 | 
				
			||||||
			Template:    tmpl,
 | 
								template:    tmpl,
 | 
				
			||||||
			ProxyPrefix: "/prefix/",
 | 
								proxyPrefix: "/prefix/",
 | 
				
			||||||
			Footer:      "Custom Footer Text",
 | 
								footer:      "Custom Footer Text",
 | 
				
			||||||
			Version:     "v0.0.0-test",
 | 
								version:     "v0.0.0-test",
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	Context("Render", func() {
 | 
						Context("WriteErrorPage", func() {
 | 
				
			||||||
		It("Writes the template to the response writer", func() {
 | 
							It("Writes the template to the response writer", func() {
 | 
				
			||||||
			recorder := httptest.NewRecorder()
 | 
								recorder := httptest.NewRecorder()
 | 
				
			||||||
			errorPage.Render(recorder, 403, "/redirect", "Access Denied")
 | 
								errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
								body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
			Expect(err).ToNot(HaveOccurred())
 | 
								Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
| 
						 | 
					@ -37,7 +37,7 @@ var _ = Describe("Error Page", func() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		It("With a different code, uses the stock message for the correct code", func() {
 | 
							It("With a different code, uses the stock message for the correct code", func() {
 | 
				
			||||||
			recorder := httptest.NewRecorder()
 | 
								recorder := httptest.NewRecorder()
 | 
				
			||||||
			errorPage.Render(recorder, 500, "/redirect", "Access Denied")
 | 
								errorPage.WriteErrorPage(recorder, 500, "/redirect", "Access Denied")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
								body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
			Expect(err).ToNot(HaveOccurred())
 | 
								Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
| 
						 | 
					@ -46,7 +46,7 @@ var _ = Describe("Error Page", func() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		It("With a message override, uses the message", func() {
 | 
							It("With a message override, uses the message", func() {
 | 
				
			||||||
			recorder := httptest.NewRecorder()
 | 
								recorder := httptest.NewRecorder()
 | 
				
			||||||
			errorPage.Render(recorder, 403, "/redirect", "Access Denied", "An extra message: %s", "with more context.")
 | 
								errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied", "An extra message: %s", "with more context.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
								body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
			Expect(err).ToNot(HaveOccurred())
 | 
								Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
| 
						 | 
					@ -71,14 +71,14 @@ var _ = Describe("Error Page", func() {
 | 
				
			||||||
			tmpl, err := template.New("").Parse("{{.Message}}")
 | 
								tmpl, err := template.New("").Parse("{{.Message}}")
 | 
				
			||||||
			Expect(err).ToNot(HaveOccurred())
 | 
								Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			errorPage.Template = tmpl
 | 
								errorPage.template = tmpl
 | 
				
			||||||
			errorPage.Debug = true
 | 
								errorPage.debug = true
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Context("Render", func() {
 | 
							Context("WriteErrorPage", func() {
 | 
				
			||||||
			It("Writes the detailed error in place of the message", func() {
 | 
								It("Writes the detailed error in place of the message", func() {
 | 
				
			||||||
				recorder := httptest.NewRecorder()
 | 
									recorder := httptest.NewRecorder()
 | 
				
			||||||
				errorPage.Render(recorder, 403, "/redirect", "Debug error")
 | 
									errorPage.WriteErrorPage(recorder, 403, "/redirect", "Debug error")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
									body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
				Expect(err).ToNot(HaveOccurred())
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,85 @@
 | 
				
			||||||
 | 
					package app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PageWriter is an interface for rendering html templates for both sign-in and
 | 
				
			||||||
 | 
					// error pages.
 | 
				
			||||||
 | 
					// It can also be used to write errors for the http.ReverseProxy used in the
 | 
				
			||||||
 | 
					// upstream package.
 | 
				
			||||||
 | 
					type PageWriter interface {
 | 
				
			||||||
 | 
						WriteSignInPage(rw http.ResponseWriter, redirectURL string)
 | 
				
			||||||
 | 
						WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{})
 | 
				
			||||||
 | 
						ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// pageWriter implements the PageWriter interface
 | 
				
			||||||
 | 
					type pageWriter struct {
 | 
				
			||||||
 | 
						*errorPageWriter
 | 
				
			||||||
 | 
						*signInPageWriter
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PageWriterOpts contains all options required to configure the template
 | 
				
			||||||
 | 
					// rendering within OAuth2 Proxy.
 | 
				
			||||||
 | 
					type PageWriterOpts struct {
 | 
				
			||||||
 | 
						// TemplatesPath is the path from which to load custom templates for the sign-in and error pages.
 | 
				
			||||||
 | 
						TemplatesPath string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// ProxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | 
				
			||||||
 | 
						ProxyPrefix string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Footer is the footer to be displayed at the bottom of the page.
 | 
				
			||||||
 | 
						// If not set, a default footer will be used.
 | 
				
			||||||
 | 
						Footer string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Version is the OAuth2 Proxy version to be used in the default footer.
 | 
				
			||||||
 | 
						Version string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Debug determines whether errors pages should be rendered with detailed
 | 
				
			||||||
 | 
						// errors.
 | 
				
			||||||
 | 
						Debug bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page.
 | 
				
			||||||
 | 
						DisplayLoginForm bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// ProviderName is the name of the provider that should be displayed on the login button.
 | 
				
			||||||
 | 
						ProviderName string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// SignInMessage is the messge displayed above the login button.
 | 
				
			||||||
 | 
						SignInMessage string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewPageWriter constructs a PageWriter from the options given to allow
 | 
				
			||||||
 | 
					// rendering of sign-in and error pages.
 | 
				
			||||||
 | 
					func NewPageWriter(opts PageWriterOpts) (PageWriter, error) {
 | 
				
			||||||
 | 
						templates, err := loadTemplates(opts.TemplatesPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("error loading templates: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						errorPage := &errorPageWriter{
 | 
				
			||||||
 | 
							template:    templates.Lookup("error.html"),
 | 
				
			||||||
 | 
							proxyPrefix: opts.ProxyPrefix,
 | 
				
			||||||
 | 
							footer:      opts.Footer,
 | 
				
			||||||
 | 
							version:     opts.Version,
 | 
				
			||||||
 | 
							debug:       opts.Debug,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						signInPage := &signInPageWriter{
 | 
				
			||||||
 | 
							template:         templates.Lookup("sign_in.html"),
 | 
				
			||||||
 | 
							errorPageWriter:  errorPage,
 | 
				
			||||||
 | 
							proxyPrefix:      opts.ProxyPrefix,
 | 
				
			||||||
 | 
							providerName:     opts.ProviderName,
 | 
				
			||||||
 | 
							signInMessage:    opts.SignInMessage,
 | 
				
			||||||
 | 
							footer:           opts.Footer,
 | 
				
			||||||
 | 
							version:          opts.Version,
 | 
				
			||||||
 | 
							displayLoginForm: opts.DisplayLoginForm,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &pageWriter{
 | 
				
			||||||
 | 
							errorPageWriter:  errorPage,
 | 
				
			||||||
 | 
							signInPageWriter: signInPage,
 | 
				
			||||||
 | 
						}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,126 @@
 | 
				
			||||||
 | 
					package app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						. "github.com/onsi/ginkgo"
 | 
				
			||||||
 | 
						. "github.com/onsi/gomega"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var _ = Describe("PageWriter", func() {
 | 
				
			||||||
 | 
						Context("NewPageWriter", func() {
 | 
				
			||||||
 | 
							var writer PageWriter
 | 
				
			||||||
 | 
							var opts PageWriterOpts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							BeforeEach(func() {
 | 
				
			||||||
 | 
								opts = PageWriterOpts{
 | 
				
			||||||
 | 
									TemplatesPath:    "",
 | 
				
			||||||
 | 
									ProxyPrefix:      "/prefix",
 | 
				
			||||||
 | 
									Footer:           "<Footer>",
 | 
				
			||||||
 | 
									Version:          "<Version>",
 | 
				
			||||||
 | 
									Debug:            false,
 | 
				
			||||||
 | 
									DisplayLoginForm: false,
 | 
				
			||||||
 | 
									ProviderName:     "<ProviderName>",
 | 
				
			||||||
 | 
									SignInMessage:    "<SignInMessage>",
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							Context("With no custom templates", func() {
 | 
				
			||||||
 | 
								BeforeEach(func() {
 | 
				
			||||||
 | 
									var err error
 | 
				
			||||||
 | 
									writer, err = NewPageWriter(opts)
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								It("Writes the default error template", func() {
 | 
				
			||||||
 | 
									recorder := httptest.NewRecorder()
 | 
				
			||||||
 | 
									writer.WriteErrorPage(recorder, 500, "/redirect", "Some debug error")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
									Expect(string(body)).To(HavePrefix("\n<!DOCTYPE html>"))
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								It("Writes the default sign in template", func() {
 | 
				
			||||||
 | 
									recorder := httptest.NewRecorder()
 | 
				
			||||||
 | 
									writer.WriteSignInPage(recorder, "/redirect")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
									Expect(string(body)).To(HavePrefix("\n<!DOCTYPE html>"))
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							Context("With custom templates", func() {
 | 
				
			||||||
 | 
								var customDir string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								BeforeEach(func() {
 | 
				
			||||||
 | 
									var err error
 | 
				
			||||||
 | 
									customDir, err = ioutil.TempDir("", "oauth2-proxy-pagewriter-test")
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									templateHTML := `Custom Template`
 | 
				
			||||||
 | 
									signInFile := filepath.Join(customDir, signInTemplateName)
 | 
				
			||||||
 | 
									Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0600)).To(Succeed())
 | 
				
			||||||
 | 
									errorFile := filepath.Join(customDir, errorTemplateName)
 | 
				
			||||||
 | 
									Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0600)).To(Succeed())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									opts.TemplatesPath = customDir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									writer, err = NewPageWriter(opts)
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								AfterEach(func() {
 | 
				
			||||||
 | 
									Expect(os.RemoveAll(customDir)).To(Succeed())
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								It("Writes the custom error template", func() {
 | 
				
			||||||
 | 
									recorder := httptest.NewRecorder()
 | 
				
			||||||
 | 
									writer.WriteErrorPage(recorder, 500, "/redirect", "Some debug error")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
									Expect(string(body)).To(Equal("Custom Template"))
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								It("Writes the custom sign in template", func() {
 | 
				
			||||||
 | 
									recorder := httptest.NewRecorder()
 | 
				
			||||||
 | 
									writer.WriteSignInPage(recorder, "/redirect")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
									Expect(string(body)).To(Equal("Custom Template"))
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							Context("With an invalid custom template", func() {
 | 
				
			||||||
 | 
								var customDir string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								BeforeEach(func() {
 | 
				
			||||||
 | 
									var err error
 | 
				
			||||||
 | 
									customDir, err = ioutil.TempDir("", "oauth2-proxy-pagewriter-test")
 | 
				
			||||||
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									templateHTML := `{{ Custom Broken Template`
 | 
				
			||||||
 | 
									signInFile := filepath.Join(customDir, signInTemplateName)
 | 
				
			||||||
 | 
									Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0600)).To(Succeed())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									opts.TemplatesPath = customDir
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								AfterEach(func() {
 | 
				
			||||||
 | 
									Expect(os.RemoveAll(customDir)).To(Succeed())
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								It("Should return an error", func() {
 | 
				
			||||||
 | 
									writer, err := NewPageWriter(opts)
 | 
				
			||||||
 | 
									Expect(err).To(MatchError(ContainSubstring("template: sign_in.html:1: function \"Custom\" not defined")))
 | 
				
			||||||
 | 
									Expect(writer).To(BeNil())
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					})
 | 
				
			||||||
| 
						 | 
					@ -7,37 +7,37 @@ import (
 | 
				
			||||||
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
 | 
						"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SignInPage is used to render sign-in pages.
 | 
					// signInPageWriter is used to render sign-in pages.
 | 
				
			||||||
type SignInPage struct {
 | 
					type signInPageWriter struct {
 | 
				
			||||||
	// Template is the sign-in page HTML template.
 | 
						// Template is the sign-in page HTML template.
 | 
				
			||||||
	Template *template.Template
 | 
						template *template.Template
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ErrorPage is used to render an error if there are problems with rendering the sign-in page.
 | 
						// errorPageWriter is used to render an error if there are problems with rendering the sign-in page.
 | 
				
			||||||
	ErrorPage *ErrorPage
 | 
						errorPageWriter *errorPageWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ProxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | 
						// ProxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | 
				
			||||||
	ProxyPrefix string
 | 
						proxyPrefix string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ProviderName is the name of the provider that should be displayed on the login button.
 | 
						// ProviderName is the name of the provider that should be displayed on the login button.
 | 
				
			||||||
	ProviderName string
 | 
						providerName string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// SignInMessage is the messge displayed above the login button.
 | 
						// SignInMessage is the messge displayed above the login button.
 | 
				
			||||||
	SignInMessage string
 | 
						signInMessage string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Footer is the footer to be displayed at the bottom of the page.
 | 
						// Footer is the footer to be displayed at the bottom of the page.
 | 
				
			||||||
	// If not set, a default footer will be used.
 | 
						// If not set, a default footer will be used.
 | 
				
			||||||
	Footer string
 | 
						footer string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Version is the OAuth2 Proxy version to be used in the default footer.
 | 
						// Version is the OAuth2 Proxy version to be used in the default footer.
 | 
				
			||||||
	Version string
 | 
						version string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page.
 | 
						// DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page.
 | 
				
			||||||
	DisplayLoginForm bool
 | 
						displayLoginForm bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Render writes the sign-in page to the given response writer.
 | 
					// WriteSignInPage writes the sign-in page to the given response writer.
 | 
				
			||||||
// It uses the redirectURL to be able to set the final destination for the user post login.
 | 
					// It uses the redirectURL to be able to set the final destination for the user post login.
 | 
				
			||||||
func (s *SignInPage) Render(rw http.ResponseWriter, redirectURL string) {
 | 
					func (s *signInPageWriter) WriteSignInPage(rw http.ResponseWriter, redirectURL string) {
 | 
				
			||||||
	// We allow unescaped template.HTML since it is user configured options
 | 
						// We allow unescaped template.HTML since it is user configured options
 | 
				
			||||||
	/* #nosec G203 */
 | 
						/* #nosec G203 */
 | 
				
			||||||
	t := struct {
 | 
						t := struct {
 | 
				
			||||||
| 
						 | 
					@ -49,18 +49,18 @@ func (s *SignInPage) Render(rw http.ResponseWriter, redirectURL string) {
 | 
				
			||||||
		ProxyPrefix   string
 | 
							ProxyPrefix   string
 | 
				
			||||||
		Footer        template.HTML
 | 
							Footer        template.HTML
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		ProviderName:  s.ProviderName,
 | 
							ProviderName:  s.providerName,
 | 
				
			||||||
		SignInMessage: template.HTML(s.SignInMessage),
 | 
							SignInMessage: template.HTML(s.signInMessage),
 | 
				
			||||||
		CustomLogin:   s.DisplayLoginForm,
 | 
							CustomLogin:   s.displayLoginForm,
 | 
				
			||||||
		Redirect:      redirectURL,
 | 
							Redirect:      redirectURL,
 | 
				
			||||||
		Version:       s.Version,
 | 
							Version:       s.version,
 | 
				
			||||||
		ProxyPrefix:   s.ProxyPrefix,
 | 
							ProxyPrefix:   s.proxyPrefix,
 | 
				
			||||||
		Footer:        template.HTML(s.Footer),
 | 
							Footer:        template.HTML(s.footer),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := s.Template.Execute(rw, t)
 | 
						err := s.template.Execute(rw, t)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Printf("Error rendering sign-in template: %v", err)
 | 
							logger.Printf("Error rendering sign-in template: %v", err)
 | 
				
			||||||
		s.ErrorPage.Render(rw, http.StatusInternalServerError, redirectURL, err.Error())
 | 
							s.errorPageWriter.WriteErrorPage(rw, http.StatusInternalServerError, redirectURL, err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,35 +9,35 @@ import (
 | 
				
			||||||
	. "github.com/onsi/gomega"
 | 
						. "github.com/onsi/gomega"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var _ = Describe("SignIn Page", func() {
 | 
					var _ = Describe("SignIn Page Writer", func() {
 | 
				
			||||||
	var signInPage *SignInPage
 | 
						var signInPage *signInPageWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	BeforeEach(func() {
 | 
						BeforeEach(func() {
 | 
				
			||||||
		errorTmpl, err := template.New("").Parse("{{.Title}}")
 | 
							errorTmpl, err := template.New("").Parse("{{.Title}}")
 | 
				
			||||||
		Expect(err).ToNot(HaveOccurred())
 | 
							Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
		errorPage := &ErrorPage{
 | 
							errorPage := &errorPageWriter{
 | 
				
			||||||
			Template: errorTmpl,
 | 
								template: errorTmpl,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		tmpl, err := template.New("").Parse("{{.ProxyPrefix}} {{.ProviderName}} {{.SignInMessage}} {{.Footer}} {{.Version}} {{.Redirect}} {{.CustomLogin}}")
 | 
							tmpl, err := template.New("").Parse("{{.ProxyPrefix}} {{.ProviderName}} {{.SignInMessage}} {{.Footer}} {{.Version}} {{.Redirect}} {{.CustomLogin}}")
 | 
				
			||||||
		Expect(err).ToNot(HaveOccurred())
 | 
							Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		signInPage = &SignInPage{
 | 
							signInPage = &signInPageWriter{
 | 
				
			||||||
			Template:         tmpl,
 | 
								template:         tmpl,
 | 
				
			||||||
			ErrorPage:        errorPage,
 | 
								errorPageWriter:  errorPage,
 | 
				
			||||||
			ProxyPrefix:      "/prefix/",
 | 
								proxyPrefix:      "/prefix/",
 | 
				
			||||||
			ProviderName:     "My Provider",
 | 
								providerName:     "My Provider",
 | 
				
			||||||
			SignInMessage:    "Sign In Here",
 | 
								signInMessage:    "Sign In Here",
 | 
				
			||||||
			Footer:           "Custom Footer Text",
 | 
								footer:           "Custom Footer Text",
 | 
				
			||||||
			Version:          "v0.0.0-test",
 | 
								version:          "v0.0.0-test",
 | 
				
			||||||
			DisplayLoginForm: true,
 | 
								displayLoginForm: true,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	Context("Render", func() {
 | 
						Context("WriteSignInPage", func() {
 | 
				
			||||||
		It("Writes the template to the response writer", func() {
 | 
							It("Writes the template to the response writer", func() {
 | 
				
			||||||
			recorder := httptest.NewRecorder()
 | 
								recorder := httptest.NewRecorder()
 | 
				
			||||||
			signInPage.Render(recorder, "/redirect")
 | 
								signInPage.WriteSignInPage(recorder, "/redirect")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
								body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
			Expect(err).ToNot(HaveOccurred())
 | 
								Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
| 
						 | 
					@ -48,10 +48,10 @@ var _ = Describe("SignIn Page", func() {
 | 
				
			||||||
			// Overwrite the template with something bad
 | 
								// Overwrite the template with something bad
 | 
				
			||||||
			tmpl, err := template.New("").Parse("{{.Unknown}}")
 | 
								tmpl, err := template.New("").Parse("{{.Unknown}}")
 | 
				
			||||||
			Expect(err).ToNot(HaveOccurred())
 | 
								Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
			signInPage.Template = tmpl
 | 
								signInPage.template = tmpl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			recorder := httptest.NewRecorder()
 | 
								recorder := httptest.NewRecorder()
 | 
				
			||||||
			signInPage.Render(recorder, "/redirect")
 | 
								signInPage.WriteSignInPage(recorder, "/redirect")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
								body, err := ioutil.ReadAll(recorder.Result().Body)
 | 
				
			||||||
			Expect(err).ToNot(HaveOccurred())
 | 
								Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -206,10 +206,10 @@ const (
 | 
				
			||||||
{{end}}`
 | 
					{{end}}`
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// LoadTemplates adds the Sign In and Error templates from the custom template
 | 
					// loadTemplates adds the Sign In and Error templates from the custom template
 | 
				
			||||||
// directory, or uses the defaults if they do not exist or the custom directory
 | 
					// directory, or uses the defaults if they do not exist or the custom directory
 | 
				
			||||||
// is not provided.
 | 
					// is not provided.
 | 
				
			||||||
func LoadTemplates(customDir string) (*template.Template, error) {
 | 
					func loadTemplates(customDir string) (*template.Template, error) {
 | 
				
			||||||
	t := template.New("").Funcs(template.FuncMap{
 | 
						t := template.New("").Funcs(template.FuncMap{
 | 
				
			||||||
		"ToUpper": strings.ToUpper,
 | 
							"ToUpper": strings.ToUpper,
 | 
				
			||||||
		"ToLower": strings.ToLower,
 | 
							"ToLower": strings.ToLower,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,16 +21,16 @@ var _ = Describe("Templates", func() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}`
 | 
							templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}`
 | 
				
			||||||
		signInFile := filepath.Join(customDir, signInTemplateName)
 | 
							signInFile := filepath.Join(customDir, signInTemplateName)
 | 
				
			||||||
		Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0666)).To(Succeed())
 | 
							Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0600)).To(Succeed())
 | 
				
			||||||
		errorFile := filepath.Join(customDir, errorTemplateName)
 | 
							errorFile := filepath.Join(customDir, errorTemplateName)
 | 
				
			||||||
		Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0666)).To(Succeed())
 | 
							Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0600)).To(Succeed())
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	AfterEach(func() {
 | 
						AfterEach(func() {
 | 
				
			||||||
		Expect(os.RemoveAll(customDir)).To(Succeed())
 | 
							Expect(os.RemoveAll(customDir)).To(Succeed())
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	Context("LoadTemplates", func() {
 | 
						Context("loadTemplates", func() {
 | 
				
			||||||
		var data interface{}
 | 
							var data interface{}
 | 
				
			||||||
		var t *template.Template
 | 
							var t *template.Template
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -73,7 +73,7 @@ var _ = Describe("Templates", func() {
 | 
				
			||||||
		Context("With no custom directory", func() {
 | 
							Context("With no custom directory", func() {
 | 
				
			||||||
			BeforeEach(func() {
 | 
								BeforeEach(func() {
 | 
				
			||||||
				var err error
 | 
									var err error
 | 
				
			||||||
				t, err = LoadTemplates("")
 | 
									t, err = loadTemplates("")
 | 
				
			||||||
				Expect(err).ToNot(HaveOccurred())
 | 
									Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -94,7 +94,7 @@ var _ = Describe("Templates", func() {
 | 
				
			||||||
			Context("With both templates", func() {
 | 
								Context("With both templates", func() {
 | 
				
			||||||
				BeforeEach(func() {
 | 
									BeforeEach(func() {
 | 
				
			||||||
					var err error
 | 
										var err error
 | 
				
			||||||
					t, err = LoadTemplates(customDir)
 | 
										t, err = loadTemplates(customDir)
 | 
				
			||||||
					Expect(err).ToNot(HaveOccurred())
 | 
										Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -116,7 +116,7 @@ var _ = Describe("Templates", func() {
 | 
				
			||||||
					Expect(os.Remove(filepath.Join(customDir, errorTemplateName))).To(Succeed())
 | 
										Expect(os.Remove(filepath.Join(customDir, errorTemplateName))).To(Succeed())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					var err error
 | 
										var err error
 | 
				
			||||||
					t, err = LoadTemplates(customDir)
 | 
										t, err = loadTemplates(customDir)
 | 
				
			||||||
					Expect(err).ToNot(HaveOccurred())
 | 
										Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -138,7 +138,7 @@ var _ = Describe("Templates", func() {
 | 
				
			||||||
					Expect(os.Remove(filepath.Join(customDir, signInTemplateName))).To(Succeed())
 | 
										Expect(os.Remove(filepath.Join(customDir, signInTemplateName))).To(Succeed())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					var err error
 | 
										var err error
 | 
				
			||||||
					t, err = LoadTemplates(customDir)
 | 
										t, err = loadTemplates(customDir)
 | 
				
			||||||
					Expect(err).ToNot(HaveOccurred())
 | 
										Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -158,11 +158,11 @@ var _ = Describe("Templates", func() {
 | 
				
			||||||
			Context("With an invalid sign_in template", func() {
 | 
								Context("With an invalid sign_in template", func() {
 | 
				
			||||||
				BeforeEach(func() {
 | 
									BeforeEach(func() {
 | 
				
			||||||
					signInFile := filepath.Join(customDir, signInTemplateName)
 | 
										signInFile := filepath.Join(customDir, signInTemplateName)
 | 
				
			||||||
					Expect(ioutil.WriteFile(signInFile, []byte("{{"), 0666))
 | 
										Expect(ioutil.WriteFile(signInFile, []byte("{{"), 0600))
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				It("Should return an error when loading templates", func() {
 | 
									It("Should return an error when loading templates", func() {
 | 
				
			||||||
					t, err := LoadTemplates(customDir)
 | 
										t, err := loadTemplates(customDir)
 | 
				
			||||||
					Expect(err).To(MatchError(HavePrefix("could not add Sign In template:")))
 | 
										Expect(err).To(MatchError(HavePrefix("could not add Sign In template:")))
 | 
				
			||||||
					Expect(t).To(BeNil())
 | 
										Expect(t).To(BeNil())
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
| 
						 | 
					@ -171,11 +171,11 @@ var _ = Describe("Templates", func() {
 | 
				
			||||||
			Context("With an invalid error template", func() {
 | 
								Context("With an invalid error template", func() {
 | 
				
			||||||
				BeforeEach(func() {
 | 
									BeforeEach(func() {
 | 
				
			||||||
					errorFile := filepath.Join(customDir, errorTemplateName)
 | 
										errorFile := filepath.Join(customDir, errorTemplateName)
 | 
				
			||||||
					Expect(ioutil.WriteFile(errorFile, []byte("{{"), 0666))
 | 
										Expect(ioutil.WriteFile(errorFile, []byte("{{"), 0600))
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				It("Should return an error when loading templates", func() {
 | 
									It("Should return an error when loading templates", func() {
 | 
				
			||||||
					t, err := LoadTemplates(customDir)
 | 
										t, err := loadTemplates(customDir)
 | 
				
			||||||
					Expect(err).To(MatchError(HavePrefix("could not add Error template:")))
 | 
										Expect(err).To(MatchError(HavePrefix("could not add Error template:")))
 | 
				
			||||||
					Expect(t).To(BeNil())
 | 
										Expect(t).To(BeNil())
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue