Fix method names in ticket.go
This commit is contained in:
		
							parent
							
								
									9861d00558
								
							
						
					
					
						commit
						408ba1a1a5
					
				|  | @ -0,0 +1,50 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | import "github.com/spf13/pflag" | ||||||
|  | 
 | ||||||
|  | // Templates includes options for configuring the sign in and error pages
 | ||||||
|  | // appearance.
 | ||||||
|  | type Templates struct { | ||||||
|  | 	// Path is the path to a folder containing a sign_in.html and an error.html
 | ||||||
|  | 	// template.
 | ||||||
|  | 	// These files will be used instead of the default templates if present.
 | ||||||
|  | 	// If either file is missing, the default will be used instead.
 | ||||||
|  | 	Path string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` | ||||||
|  | 
 | ||||||
|  | 	// Banner overides the default sign_in page banner text. If unspecified,
 | ||||||
|  | 	// the message will give users a list of allowed email domains.
 | ||||||
|  | 	Banner string `flag:"banner" cfg:"banner"` | ||||||
|  | 
 | ||||||
|  | 	// Footer overrides the default sign_in page footer text.
 | ||||||
|  | 	Footer string `flag:"footer" cfg:"footer"` | ||||||
|  | 
 | ||||||
|  | 	// DisplayLoginForm determines whether the sign_in page should render a
 | ||||||
|  | 	// password form if a static passwords file (htpasswd file) has been
 | ||||||
|  | 	// configured.
 | ||||||
|  | 	DisplayLoginForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` | ||||||
|  | 
 | ||||||
|  | 	// Debug renders detailed errors when an error page is shown.
 | ||||||
|  | 	// It is not advised to use this in production as errors may contain sensitive
 | ||||||
|  | 	// information.
 | ||||||
|  | 	// Use only for diagnosing backend errors.
 | ||||||
|  | 	Debug bool `flag:"show-debug-on-error" cfg:"show-debug-on-error"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func templatesFlagSet() *pflag.FlagSet { | ||||||
|  | 	flagSet := pflag.NewFlagSet("templates", pflag.ExitOnError) | ||||||
|  | 
 | ||||||
|  | 	flagSet.String("custom-templates-dir", "", "path to custom html templates") | ||||||
|  | 	flagSet.String("banner", "", "custom banner string. Use \"-\" to disable default banner.") | ||||||
|  | 	flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") | ||||||
|  | 	flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") | ||||||
|  | 	flagSet.Bool("show-debug-on-error", false, "show detailed error information on error pages (WARNING: this may contain sensitive information - do not use in production)") | ||||||
|  | 
 | ||||||
|  | 	return flagSet | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // templatesDefaults creates a Templates and populates it with any default values
 | ||||||
|  | func templatesDefaults() Templates { | ||||||
|  | 	return Templates{ | ||||||
|  | 		DisplayLoginForm: true, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,97 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"html/template" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // errorMessages are default error messages for each of the the different
 | ||||||
|  | // http status codes expected to be rendered in the error page.
 | ||||||
|  | var errorMessages = map[int]string{ | ||||||
|  | 	http.StatusInternalServerError: "Oops! Something went wrong. For more information contact your server administrator.", | ||||||
|  | 	http.StatusNotFound:            "We could not find the resource you were looking for.", | ||||||
|  | 	http.StatusForbidden:           "You do not have permission to access this resource.", | ||||||
|  | 	http.StatusUnauthorized:        "You need to be logged in to access this resource.", | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // errorPageWriter is used to render error pages.
 | ||||||
|  | type errorPageWriter struct { | ||||||
|  | 	// template is the error page HTML template.
 | ||||||
|  | 	template *template.Template | ||||||
|  | 
 | ||||||
|  | 	// proxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | ||||||
|  | 	proxyPrefix string | ||||||
|  | 
 | ||||||
|  | 	// footer is the footer to be displayed at the bottom of the page.
 | ||||||
|  | 	// If not set, a default footer will be used.
 | ||||||
|  | 	footer string | ||||||
|  | 
 | ||||||
|  | 	// version is the OAuth2 Proxy version to be used in the default footer.
 | ||||||
|  | 	version string | ||||||
|  | 
 | ||||||
|  | 	// debug determines whether errors pages should be rendered with detailed
 | ||||||
|  | 	// errors.
 | ||||||
|  | 	debug bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WriteErrorPage writes an error page to the given response writer.
 | ||||||
|  | // It uses the passed redirectURL to give users the option to go back to where
 | ||||||
|  | // they originally came from or try signing in again.
 | ||||||
|  | func (e *errorPageWriter) WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) { | ||||||
|  | 	rw.WriteHeader(status) | ||||||
|  | 
 | ||||||
|  | 	// We allow unescaped template.HTML since it is user configured options
 | ||||||
|  | 	/* #nosec G203 */ | ||||||
|  | 	data := struct { | ||||||
|  | 		Title       string | ||||||
|  | 		Message     string | ||||||
|  | 		ProxyPrefix string | ||||||
|  | 		StatusCode  int | ||||||
|  | 		Redirect    string | ||||||
|  | 		Footer      template.HTML | ||||||
|  | 		Version     string | ||||||
|  | 	}{ | ||||||
|  | 		Title:       http.StatusText(status), | ||||||
|  | 		Message:     e.getMessage(status, appError, messages...), | ||||||
|  | 		ProxyPrefix: e.proxyPrefix, | ||||||
|  | 		StatusCode:  status, | ||||||
|  | 		Redirect:    redirectURL, | ||||||
|  | 		Footer:      template.HTML(e.footer), | ||||||
|  | 		Version:     e.version, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := e.template.Execute(rw, data); err != nil { | ||||||
|  | 		logger.Printf("Error rendering error template: %v", err) | ||||||
|  | 		http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ProxyErrorHandler is used by the upstream ReverseProxy to render error pages
 | ||||||
|  | // when there are issues with upstream servers.
 | ||||||
|  | // It is expected to always render a bad gateway error.
 | ||||||
|  | func (e *errorPageWriter) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { | ||||||
|  | 	logger.Errorf("Error proxying to upstream server: %v", proxyErr) | ||||||
|  | 	e.WriteErrorPage(rw, http.StatusBadGateway, "", proxyErr.Error(), "There was a problem connecting to the upstream server.") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getMessage creates the message for the template parameters.
 | ||||||
|  | // If the errorPagewriter.Debug is enabled, the application error takes precedence.
 | ||||||
|  | // Otherwise, any messages will be used.
 | ||||||
|  | // The first message is expected to be a format string.
 | ||||||
|  | // If no messages are supplied, a default error message will be used.
 | ||||||
|  | func (e *errorPageWriter) getMessage(status int, appError string, messages ...interface{}) string { | ||||||
|  | 	if e.debug { | ||||||
|  | 		return appError | ||||||
|  | 	} | ||||||
|  | 	if len(messages) > 0 { | ||||||
|  | 		format := fmt.Sprintf("%v", messages[0]) | ||||||
|  | 		return fmt.Sprintf(format, messages[1:]...) | ||||||
|  | 	} | ||||||
|  | 	if msg, ok := errorMessages[status]; ok { | ||||||
|  | 		return msg | ||||||
|  | 	} | ||||||
|  | 	return "Unknown error" | ||||||
|  | } | ||||||
|  | @ -0,0 +1,101 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"html/template" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Error Page Writer", func() { | ||||||
|  | 	var errorPage *errorPageWriter | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		tmpl, err := template.New("").Parse("{{.Title}} {{.Message}} {{.ProxyPrefix}} {{.StatusCode}} {{.Redirect}} {{.Footer}} {{.Version}}") | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		errorPage = &errorPageWriter{ | ||||||
|  | 			template:    tmpl, | ||||||
|  | 			proxyPrefix: "/prefix/", | ||||||
|  | 			footer:      "Custom Footer Text", | ||||||
|  | 			version:     "v0.0.0-test", | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("WriteErrorPage", func() { | ||||||
|  | 		It("Writes the template to the response writer", func() { | ||||||
|  | 			recorder := httptest.NewRecorder() | ||||||
|  | 			errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied") | ||||||
|  | 
 | ||||||
|  | 			body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(string(body)).To(Equal("Forbidden You do not have permission to access this resource. /prefix/ 403 /redirect Custom Footer Text v0.0.0-test")) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("With a different code, uses the stock message for the correct code", func() { | ||||||
|  | 			recorder := httptest.NewRecorder() | ||||||
|  | 			errorPage.WriteErrorPage(recorder, 500, "/redirect", "Access Denied") | ||||||
|  | 
 | ||||||
|  | 			body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(string(body)).To(Equal("Internal Server Error Oops! Something went wrong. For more information contact your server administrator. /prefix/ 500 /redirect Custom Footer Text v0.0.0-test")) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("With a message override, uses the message", func() { | ||||||
|  | 			recorder := httptest.NewRecorder() | ||||||
|  | 			errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied", "An extra message: %s", "with more context.") | ||||||
|  | 
 | ||||||
|  | 			body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(string(body)).To(Equal("Forbidden An extra message: with more context. /prefix/ 403 /redirect Custom Footer Text v0.0.0-test")) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("ProxyErrorHandler", func() { | ||||||
|  | 		It("Writes a bad gateway error the response writer", func() { | ||||||
|  | 			req := httptest.NewRequest("", "/bad-gateway", nil) | ||||||
|  | 			recorder := httptest.NewRecorder() | ||||||
|  | 			errorPage.ProxyErrorHandler(recorder, req, errors.New("some upstream error")) | ||||||
|  | 
 | ||||||
|  | 			body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(string(body)).To(Equal("Bad Gateway There was a problem connecting to the upstream server. /prefix/ 502  Custom Footer Text v0.0.0-test")) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("With Debug enabled", func() { | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			tmpl, err := template.New("").Parse("{{.Message}}") | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			errorPage.template = tmpl | ||||||
|  | 			errorPage.debug = true | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("WriteErrorPage", func() { | ||||||
|  | 			It("Writes the detailed error in place of the message", func() { | ||||||
|  | 				recorder := httptest.NewRecorder() | ||||||
|  | 				errorPage.WriteErrorPage(recorder, 403, "/redirect", "Debug error") | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal("Debug error")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("ProxyErrorHandler", func() { | ||||||
|  | 			It("Writes a bad gateway error the response writer", func() { | ||||||
|  | 				req := httptest.NewRequest("", "/bad-gateway", nil) | ||||||
|  | 				recorder := httptest.NewRecorder() | ||||||
|  | 				errorPage.ProxyErrorHandler(recorder, req, errors.New("some upstream error")) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal("some upstream error")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,85 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Writer is an interface for rendering html templates for both sign-in and
 | ||||||
|  | // error pages.
 | ||||||
|  | // It can also be used to write errors for the http.ReverseProxy used in the
 | ||||||
|  | // upstream package.
 | ||||||
|  | type Writer interface { | ||||||
|  | 	WriteSignInPage(rw http.ResponseWriter, redirectURL string) | ||||||
|  | 	WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) | ||||||
|  | 	ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // pageWriter implements the Writer interface
 | ||||||
|  | type pageWriter struct { | ||||||
|  | 	*errorPageWriter | ||||||
|  | 	*signInPageWriter | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Opts contains all options required to configure the template
 | ||||||
|  | // rendering within OAuth2 Proxy.
 | ||||||
|  | type Opts struct { | ||||||
|  | 	// TemplatesPath is the path from which to load custom templates for the sign-in and error pages.
 | ||||||
|  | 	TemplatesPath string | ||||||
|  | 
 | ||||||
|  | 	// ProxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | ||||||
|  | 	ProxyPrefix string | ||||||
|  | 
 | ||||||
|  | 	// Footer is the footer to be displayed at the bottom of the page.
 | ||||||
|  | 	// If not set, a default footer will be used.
 | ||||||
|  | 	Footer string | ||||||
|  | 
 | ||||||
|  | 	// Version is the OAuth2 Proxy version to be used in the default footer.
 | ||||||
|  | 	Version string | ||||||
|  | 
 | ||||||
|  | 	// Debug determines whether errors pages should be rendered with detailed
 | ||||||
|  | 	// errors.
 | ||||||
|  | 	Debug bool | ||||||
|  | 
 | ||||||
|  | 	// DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page.
 | ||||||
|  | 	DisplayLoginForm bool | ||||||
|  | 
 | ||||||
|  | 	// ProviderName is the name of the provider that should be displayed on the login button.
 | ||||||
|  | 	ProviderName string | ||||||
|  | 
 | ||||||
|  | 	// SignInMessage is the messge displayed above the login button.
 | ||||||
|  | 	SignInMessage string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewWriter constructs a Writer from the options given to allow
 | ||||||
|  | // rendering of sign-in and error pages.
 | ||||||
|  | func NewWriter(opts Opts) (Writer, error) { | ||||||
|  | 	templates, err := loadTemplates(opts.TemplatesPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error loading templates: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	errorPage := &errorPageWriter{ | ||||||
|  | 		template:    templates.Lookup("error.html"), | ||||||
|  | 		proxyPrefix: opts.ProxyPrefix, | ||||||
|  | 		footer:      opts.Footer, | ||||||
|  | 		version:     opts.Version, | ||||||
|  | 		debug:       opts.Debug, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	signInPage := &signInPageWriter{ | ||||||
|  | 		template:         templates.Lookup("sign_in.html"), | ||||||
|  | 		errorPageWriter:  errorPage, | ||||||
|  | 		proxyPrefix:      opts.ProxyPrefix, | ||||||
|  | 		providerName:     opts.ProviderName, | ||||||
|  | 		signInMessage:    opts.SignInMessage, | ||||||
|  | 		footer:           opts.Footer, | ||||||
|  | 		version:          opts.Version, | ||||||
|  | 		displayLoginForm: opts.DisplayLoginForm, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &pageWriter{ | ||||||
|  | 		errorPageWriter:  errorPage, | ||||||
|  | 		signInPageWriter: signInPage, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | @ -0,0 +1,17 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestOptionsSuite(t *testing.T) { | ||||||
|  | 	logger.SetOutput(GinkgoWriter) | ||||||
|  | 	logger.SetErrOutput(GinkgoWriter) | ||||||
|  | 
 | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "App Suite") | ||||||
|  | } | ||||||
|  | @ -0,0 +1,126 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Writer", func() { | ||||||
|  | 	Context("NewWriter", func() { | ||||||
|  | 		var writer Writer | ||||||
|  | 		var opts Opts | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			opts = Opts{ | ||||||
|  | 				TemplatesPath:    "", | ||||||
|  | 				ProxyPrefix:      "/prefix", | ||||||
|  | 				Footer:           "<Footer>", | ||||||
|  | 				Version:          "<Version>", | ||||||
|  | 				Debug:            false, | ||||||
|  | 				DisplayLoginForm: false, | ||||||
|  | 				ProviderName:     "<ProviderName>", | ||||||
|  | 				SignInMessage:    "<SignInMessage>", | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("With no custom templates", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				writer, err = NewWriter(opts) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Writes the default error template", func() { | ||||||
|  | 				recorder := httptest.NewRecorder() | ||||||
|  | 				writer.WriteErrorPage(recorder, 500, "/redirect", "Some debug error") | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Writes the default sign in template", func() { | ||||||
|  | 				recorder := httptest.NewRecorder() | ||||||
|  | 				writer.WriteSignInPage(recorder, "/redirect") | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("With custom templates", func() { | ||||||
|  | 			var customDir string | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				customDir, err = ioutil.TempDir("", "oauth2-proxy-pagewriter-test") | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				templateHTML := `Custom Template` | ||||||
|  | 				signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 				Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0600)).To(Succeed()) | ||||||
|  | 				errorFile := filepath.Join(customDir, errorTemplateName) | ||||||
|  | 				Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0600)).To(Succeed()) | ||||||
|  | 
 | ||||||
|  | 				opts.TemplatesPath = customDir | ||||||
|  | 
 | ||||||
|  | 				writer, err = NewWriter(opts) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			AfterEach(func() { | ||||||
|  | 				Expect(os.RemoveAll(customDir)).To(Succeed()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Writes the custom error template", func() { | ||||||
|  | 				recorder := httptest.NewRecorder() | ||||||
|  | 				writer.WriteErrorPage(recorder, 500, "/redirect", "Some debug error") | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal("Custom Template")) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Writes the custom sign in template", func() { | ||||||
|  | 				recorder := httptest.NewRecorder() | ||||||
|  | 				writer.WriteSignInPage(recorder, "/redirect") | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal("Custom Template")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("With an invalid custom template", func() { | ||||||
|  | 			var customDir string | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				customDir, err = ioutil.TempDir("", "oauth2-proxy-pagewriter-test") | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				templateHTML := `{{ Custom Broken Template` | ||||||
|  | 				signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 				Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0600)).To(Succeed()) | ||||||
|  | 
 | ||||||
|  | 				opts.TemplatesPath = customDir | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			AfterEach(func() { | ||||||
|  | 				Expect(os.RemoveAll(customDir)).To(Succeed()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Should return an error", func() { | ||||||
|  | 				writer, err := NewWriter(opts) | ||||||
|  | 				Expect(err).To(MatchError(ContainSubstring("template: sign_in.html:1: function \"Custom\" not defined"))) | ||||||
|  | 				Expect(writer).To(BeNil()) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,66 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"html/template" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // signInPageWriter is used to render sign-in pages.
 | ||||||
|  | type signInPageWriter struct { | ||||||
|  | 	// Template is the sign-in page HTML template.
 | ||||||
|  | 	template *template.Template | ||||||
|  | 
 | ||||||
|  | 	// errorPageWriter is used to render an error if there are problems with rendering the sign-in page.
 | ||||||
|  | 	errorPageWriter *errorPageWriter | ||||||
|  | 
 | ||||||
|  | 	// ProxyPrefix is the prefix under which OAuth2 Proxy pages are served.
 | ||||||
|  | 	proxyPrefix string | ||||||
|  | 
 | ||||||
|  | 	// ProviderName is the name of the provider that should be displayed on the login button.
 | ||||||
|  | 	providerName string | ||||||
|  | 
 | ||||||
|  | 	// SignInMessage is the messge displayed above the login button.
 | ||||||
|  | 	signInMessage string | ||||||
|  | 
 | ||||||
|  | 	// Footer is the footer to be displayed at the bottom of the page.
 | ||||||
|  | 	// If not set, a default footer will be used.
 | ||||||
|  | 	footer string | ||||||
|  | 
 | ||||||
|  | 	// Version is the OAuth2 Proxy version to be used in the default footer.
 | ||||||
|  | 	version string | ||||||
|  | 
 | ||||||
|  | 	// DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page.
 | ||||||
|  | 	displayLoginForm bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WriteSignInPage writes the sign-in page to the given response writer.
 | ||||||
|  | // It uses the redirectURL to be able to set the final destination for the user post login.
 | ||||||
|  | func (s *signInPageWriter) WriteSignInPage(rw http.ResponseWriter, redirectURL string) { | ||||||
|  | 	// We allow unescaped template.HTML since it is user configured options
 | ||||||
|  | 	/* #nosec G203 */ | ||||||
|  | 	t := struct { | ||||||
|  | 		ProviderName  string | ||||||
|  | 		SignInMessage template.HTML | ||||||
|  | 		CustomLogin   bool | ||||||
|  | 		Redirect      string | ||||||
|  | 		Version       string | ||||||
|  | 		ProxyPrefix   string | ||||||
|  | 		Footer        template.HTML | ||||||
|  | 	}{ | ||||||
|  | 		ProviderName:  s.providerName, | ||||||
|  | 		SignInMessage: template.HTML(s.signInMessage), | ||||||
|  | 		CustomLogin:   s.displayLoginForm, | ||||||
|  | 		Redirect:      redirectURL, | ||||||
|  | 		Version:       s.version, | ||||||
|  | 		ProxyPrefix:   s.proxyPrefix, | ||||||
|  | 		Footer:        template.HTML(s.footer), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err := s.template.Execute(rw, t) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Printf("Error rendering sign-in template: %v", err) | ||||||
|  | 		s.errorPageWriter.WriteErrorPage(rw, http.StatusInternalServerError, redirectURL, err.Error()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,61 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"html/template" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("SignIn Page Writer", func() { | ||||||
|  | 	var signInPage *signInPageWriter | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		errorTmpl, err := template.New("").Parse("{{.Title}}") | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 		errorPage := &errorPageWriter{ | ||||||
|  | 			template: errorTmpl, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		tmpl, err := template.New("").Parse("{{.ProxyPrefix}} {{.ProviderName}} {{.SignInMessage}} {{.Footer}} {{.Version}} {{.Redirect}} {{.CustomLogin}}") | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		signInPage = &signInPageWriter{ | ||||||
|  | 			template:         tmpl, | ||||||
|  | 			errorPageWriter:  errorPage, | ||||||
|  | 			proxyPrefix:      "/prefix/", | ||||||
|  | 			providerName:     "My Provider", | ||||||
|  | 			signInMessage:    "Sign In Here", | ||||||
|  | 			footer:           "Custom Footer Text", | ||||||
|  | 			version:          "v0.0.0-test", | ||||||
|  | 			displayLoginForm: true, | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("WriteSignInPage", func() { | ||||||
|  | 		It("Writes the template to the response writer", func() { | ||||||
|  | 			recorder := httptest.NewRecorder() | ||||||
|  | 			signInPage.WriteSignInPage(recorder, "/redirect") | ||||||
|  | 
 | ||||||
|  | 			body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(string(body)).To(Equal("/prefix/ My Provider Sign In Here Custom Footer Text v0.0.0-test /redirect true")) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("Writes an error if the template can't be rendered", func() { | ||||||
|  | 			// Overwrite the template with something bad
 | ||||||
|  | 			tmpl, err := template.New("").Parse("{{.Unknown}}") | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			signInPage.template = tmpl | ||||||
|  | 
 | ||||||
|  | 			recorder := httptest.NewRecorder() | ||||||
|  | 			signInPage.WriteSignInPage(recorder, "/redirect") | ||||||
|  | 
 | ||||||
|  | 			body, err := ioutil.ReadAll(recorder.Result().Body) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(string(body)).To(Equal("Internal Server Error")) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,259 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"html/template" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	errorTemplateName  = "error.html" | ||||||
|  | 	signInTemplateName = "sign_in.html" | ||||||
|  | 
 | ||||||
|  | 	defaultErrorTemplate = `{{define "error.html"}} | ||||||
|  | <!DOCTYPE html> | ||||||
|  | <html lang="en" charset="utf-8"> | ||||||
|  | <head> | ||||||
|  |   <meta charset="utf-8"> | ||||||
|  |   <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> | ||||||
|  |   	<title>{{.StatusCode}} {{.Title}}</title> | ||||||
|  |   <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css"> | ||||||
|  |   <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.2/css/all.min.css"> | ||||||
|  | 
 | ||||||
|  |   <script type="text/javascript"> | ||||||
|  |     document.addEventListener('DOMContentLoaded', function() { | ||||||
|  |     	let cardToggles = document.getElementsByClassName('card-toggle'); | ||||||
|  |     	for (let i = 0; i < cardToggles.length; i++) { | ||||||
|  |     		cardToggles[i].addEventListener('click', e => { | ||||||
|  |     			e.currentTarget.parentElement.parentElement.childNodes[3].classList.toggle('is-hidden'); | ||||||
|  |     		}); | ||||||
|  |     	} | ||||||
|  |     }); | ||||||
|  |   </script> | ||||||
|  | 
 | ||||||
|  |   <style> | ||||||
|  |     body { | ||||||
|  |       height: 100vh; | ||||||
|  |     } | ||||||
|  |     .error-box { | ||||||
|  |       margin: 1.25rem auto; | ||||||
|  |       max-width: 600px; | ||||||
|  |     } | ||||||
|  |     .status-code { | ||||||
|  |       font-size: 12rem; | ||||||
|  |       font-weight: 600; | ||||||
|  |     } | ||||||
|  |     #more-info.card { | ||||||
|  |       border: 1px solid #f0f0f0; | ||||||
|  |     } | ||||||
|  |     footer a { | ||||||
|  |       text-decoration: underline; | ||||||
|  |     } | ||||||
|  |   </style> | ||||||
|  | </head> | ||||||
|  | <body class="has-background-light"> | ||||||
|  |   <section class="section"> | ||||||
|  |     <div class="box block error-box has-text-centered"> | ||||||
|  |       <div class="status-code">{{.StatusCode}}</div> | ||||||
|  |       <div class="block"> | ||||||
|  |         <h1 class="subtitle is-1">{{.Title}}</h1> | ||||||
|  |       </div> | ||||||
|  | 
 | ||||||
|  |       {{ if .Message }} | ||||||
|  |       <div id="more-info" class="block card is-fullwidth is-shadowless"> | ||||||
|  |   			<header class="card-header is-shadowless"> | ||||||
|  |   				<p class="card-header-title">More Info</p> | ||||||
|  |   				<a class="card-header-icon card-toggle"> | ||||||
|  |   					<i class="fa fa-angle-down"></i> | ||||||
|  |   				</a> | ||||||
|  |   			</header> | ||||||
|  |   			<div class="card-content has-text-left is-hidden"> | ||||||
|  |   				<div class="content"> | ||||||
|  |   					{{.Message}} | ||||||
|  |   				</div> | ||||||
|  |   			</div> | ||||||
|  |   		</div> | ||||||
|  |       {{ end }} | ||||||
|  | 
 | ||||||
|  |       {{ if .Redirect }} | ||||||
|  |       <hr> | ||||||
|  | 
 | ||||||
|  |       <div class="columns"> | ||||||
|  |         <div class="column"> | ||||||
|  |           <form method="GET" action="{{.Redirect}}"> | ||||||
|  |             <button type="submit" class="button is-danger is-fullwidth">Go back</button> | ||||||
|  |           </form> | ||||||
|  |         </div> | ||||||
|  |         <div class="column"> | ||||||
|  |           <form method="GET" action="{{.ProxyPrefix}}/sign_in"> | ||||||
|  |             <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  |             <button type="submit" class="button is-primary is-fullwidth">Sign in</button> | ||||||
|  |           </form> | ||||||
|  |         </div> | ||||||
|  |       </div> | ||||||
|  |       {{ end }} | ||||||
|  | 
 | ||||||
|  |     </div> | ||||||
|  |   </section> | ||||||
|  | 
 | ||||||
|  |   <footer class="footer has-text-grey has-background-light is-size-7"> | ||||||
|  |     <div class="content has-text-centered"> | ||||||
|  |     	{{ if eq .Footer "-" }} | ||||||
|  |     	{{ else if eq .Footer ""}} | ||||||
|  |     	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p> | ||||||
|  |     	{{ else }} | ||||||
|  |     	<p>{{.Footer}}</p> | ||||||
|  |     	{{ end }} | ||||||
|  |     </div> | ||||||
|  | 	</footer> | ||||||
|  | 
 | ||||||
|  |   </body> | ||||||
|  | </html> | ||||||
|  | {{end}}` | ||||||
|  | 
 | ||||||
|  | 	defaultSignInTemplate = `{{define "sign_in.html"}} | ||||||
|  | <!DOCTYPE html> | ||||||
|  | <html lang="en" charset="utf-8"> | ||||||
|  |   <head> | ||||||
|  |     <meta charset="utf-8"> | ||||||
|  |     <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> | ||||||
|  |     <title>Sign In</title> | ||||||
|  |     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css"> | ||||||
|  | 
 | ||||||
|  |     <style> | ||||||
|  |       body { | ||||||
|  |         height: 100vh; | ||||||
|  |       } | ||||||
|  |       .sign-in-box { | ||||||
|  |         max-width: 400px; | ||||||
|  |         margin: 1.25rem auto; | ||||||
|  |       } | ||||||
|  |       footer a { | ||||||
|  |         text-decoration: underline; | ||||||
|  |       } | ||||||
|  |     </style> | ||||||
|  | 
 | ||||||
|  |     <script> | ||||||
|  |       if (window.location.hash) { | ||||||
|  |         (function() { | ||||||
|  |           var inputs = document.getElementsByName('rd'); | ||||||
|  |           for (var i = 0; i < inputs.length; i++) { | ||||||
|  |             // Add hash, but make sure it is only added once
 | ||||||
|  |             var idx = inputs[i].value.indexOf('#'); | ||||||
|  |             if (idx >= 0) { | ||||||
|  |               // Remove existing hash from URL
 | ||||||
|  |               inputs[i].value = inputs[i].value.substr(0, idx); | ||||||
|  |             } | ||||||
|  |             inputs[i].value += window.location.hash; | ||||||
|  |           } | ||||||
|  |         })(); | ||||||
|  |       } | ||||||
|  |     </script> | ||||||
|  |   </head> | ||||||
|  |   <body class="has-background-light"> | ||||||
|  |   <section class="section"> | ||||||
|  |     <div class="box block sign-in-box has-text-centered"> | ||||||
|  |       <form method="GET" action="{{.ProxyPrefix}}/start"> | ||||||
|  |         <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  |           {{ if .SignInMessage }} | ||||||
|  |           <p class="block">{{.SignInMessage}}</p> | ||||||
|  |           {{ end}} | ||||||
|  |           <button type="submit" class="button block is-primary">Sign in with {{.ProviderName}}</button> | ||||||
|  |       </form> | ||||||
|  | 
 | ||||||
|  |       {{ if .CustomLogin }} | ||||||
|  |       <hr> | ||||||
|  | 
 | ||||||
|  |       <form method="POST" action="{{.ProxyPrefix}}/sign_in" class="block"> | ||||||
|  |         <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  | 
 | ||||||
|  |         <div class="field"> | ||||||
|  |           <label class="label" for="username">Username</label> | ||||||
|  |           <div class="control"> | ||||||
|  |             <input class="input" type="email" placeholder="e.g. userx@example.com"  name="username" id="username"> | ||||||
|  |           </div> | ||||||
|  |         </div> | ||||||
|  | 
 | ||||||
|  |         <div class="field"> | ||||||
|  |           <label class="label" for="password">Password</label> | ||||||
|  |           <div class="control"> | ||||||
|  |             <input class="input" type="password" placeholder="********" name="password" id="password"> | ||||||
|  |           </div> | ||||||
|  |         </div> | ||||||
|  |         <button class="button is-primary">Sign in</button> | ||||||
|  |         {{ end }} | ||||||
|  |     </form> | ||||||
|  |     </div> | ||||||
|  |   </section> | ||||||
|  | 
 | ||||||
|  |   <footer class="footer has-text-grey has-background-light is-size-7"> | ||||||
|  |     <div class="content has-text-centered"> | ||||||
|  |     	{{ if eq .Footer "-" }} | ||||||
|  |     	{{ else if eq .Footer ""}} | ||||||
|  |     	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p> | ||||||
|  |     	{{ else }} | ||||||
|  |     	<p>{{.Footer}}</p> | ||||||
|  |     	{{ end }} | ||||||
|  |     </div> | ||||||
|  | 	</footer> | ||||||
|  | 
 | ||||||
|  |   </body> | ||||||
|  | </html> | ||||||
|  | {{end}}` | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // loadTemplates adds the Sign In and Error templates from the custom template
 | ||||||
|  | // directory, or uses the defaults if they do not exist or the custom directory
 | ||||||
|  | // is not provided.
 | ||||||
|  | func loadTemplates(customDir string) (*template.Template, error) { | ||||||
|  | 	t := template.New("").Funcs(template.FuncMap{ | ||||||
|  | 		"ToUpper": strings.ToUpper, | ||||||
|  | 		"ToLower": strings.ToLower, | ||||||
|  | 	}) | ||||||
|  | 	var err error | ||||||
|  | 	t, err = addTemplate(t, customDir, signInTemplateName, defaultSignInTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not add Sign In template: %v", err) | ||||||
|  | 	} | ||||||
|  | 	t, err = addTemplate(t, customDir, errorTemplateName, defaultErrorTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not add Error template: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // addTemplate will add the template from the custom directory if provided,
 | ||||||
|  | // else it will add the default template.
 | ||||||
|  | func addTemplate(t *template.Template, customDir, fileName, defaultTemplate string) (*template.Template, error) { | ||||||
|  | 	filePath := filepath.Join(customDir, fileName) | ||||||
|  | 	if customDir != "" && isFile(filePath) { | ||||||
|  | 		t, err := t.ParseFiles(filePath) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("failed to parse template %s: %v", filePath, err) | ||||||
|  | 		} | ||||||
|  | 		return t, nil | ||||||
|  | 	} | ||||||
|  | 	t, err := t.Parse(defaultTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// This should not happen.
 | ||||||
|  | 		// Default templates should be tested and so should never fail to parse.
 | ||||||
|  | 		logger.Panic("Could not parse defaultTemplate: ", err) | ||||||
|  | 	} | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // isFile checks if the file exists and checks whether it is a regular file.
 | ||||||
|  | // If either of these fail then it cannot be used as a template file.
 | ||||||
|  | func isFile(fileName string) bool { | ||||||
|  | 	info, err := os.Stat(fileName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Could not load file %s: %v, will use default template", fileName, err) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return info.Mode().IsRegular() | ||||||
|  | } | ||||||
|  | @ -0,0 +1,199 @@ | ||||||
|  | package pagewriter | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"html/template" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Templates", func() { | ||||||
|  | 	var customDir string | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		var err error | ||||||
|  | 		customDir, err = ioutil.TempDir("", "oauth2-proxy-templates-test") | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` | ||||||
|  | 		signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 		Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0600)).To(Succeed()) | ||||||
|  | 		errorFile := filepath.Join(customDir, errorTemplateName) | ||||||
|  | 		Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0600)).To(Succeed()) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	AfterEach(func() { | ||||||
|  | 		Expect(os.RemoveAll(customDir)).To(Succeed()) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("loadTemplates", func() { | ||||||
|  | 		var data interface{} | ||||||
|  | 		var t *template.Template | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			data = struct { | ||||||
|  | 				// For default templates
 | ||||||
|  | 				ProxyPrefix string | ||||||
|  | 				Redirect    string | ||||||
|  | 				Footer      string | ||||||
|  | 
 | ||||||
|  | 				// For default sign_in template
 | ||||||
|  | 				SignInMessage string | ||||||
|  | 				ProviderName  string | ||||||
|  | 				CustomLogin   bool | ||||||
|  | 
 | ||||||
|  | 				// For default error template
 | ||||||
|  | 				StatusCode int | ||||||
|  | 				Title      string | ||||||
|  | 				Message    string | ||||||
|  | 
 | ||||||
|  | 				// For custom templates
 | ||||||
|  | 				TestString string | ||||||
|  | 			}{ | ||||||
|  | 				ProxyPrefix: "<proxy-prefix>", | ||||||
|  | 				Redirect:    "<redirect>", | ||||||
|  | 				Footer:      "<footer>", | ||||||
|  | 
 | ||||||
|  | 				SignInMessage: "<sign-in-message>", | ||||||
|  | 				ProviderName:  "<provider-name>", | ||||||
|  | 				CustomLogin:   false, | ||||||
|  | 
 | ||||||
|  | 				StatusCode: 404, | ||||||
|  | 				Title:      "<title>", | ||||||
|  | 				Message:    "<message>", | ||||||
|  | 
 | ||||||
|  | 				TestString: "Testing", | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("With no custom directory", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				t, err = loadTemplates("") | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Use the default sign_in page", func() { | ||||||
|  | 				buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 				Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Use the default error page", func() { | ||||||
|  | 				buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 				Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("With a custom directory", func() { | ||||||
|  | 			Context("With both templates", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					var err error | ||||||
|  | 					t, err = loadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With no error template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					Expect(os.Remove(filepath.Join(customDir, errorTemplateName))).To(Succeed()) | ||||||
|  | 
 | ||||||
|  | 					var err error | ||||||
|  | 					t, err = loadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the default error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With no sign_in template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					Expect(os.Remove(filepath.Join(customDir, signInTemplateName))).To(Succeed()) | ||||||
|  | 
 | ||||||
|  | 					var err error | ||||||
|  | 					t, err = loadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the default sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With an invalid sign_in template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 					Expect(ioutil.WriteFile(signInFile, []byte("{{"), 0600)) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Should return an error when loading templates", func() { | ||||||
|  | 					t, err := loadTemplates(customDir) | ||||||
|  | 					Expect(err).To(MatchError(HavePrefix("could not add Sign In template:"))) | ||||||
|  | 					Expect(t).To(BeNil()) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With an invalid error template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					errorFile := filepath.Join(customDir, errorTemplateName) | ||||||
|  | 					Expect(ioutil.WriteFile(errorFile, []byte("{{"), 0600)) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Should return an error when loading templates", func() { | ||||||
|  | 					t, err := loadTemplates(customDir) | ||||||
|  | 					Expect(err).To(MatchError(HavePrefix("could not add Error template:"))) | ||||||
|  | 					Expect(t).To(BeNil()) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("isFile", func() { | ||||||
|  | 		It("with a valid file", func() { | ||||||
|  | 			Expect(isFile(filepath.Join(customDir, signInTemplateName))).To(BeTrue()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("with a directory", func() { | ||||||
|  | 			Expect(isFile(customDir)).To(BeFalse()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("with an invalid file", func() { | ||||||
|  | 			Expect(isFile(filepath.Join(customDir, "does_not_exist.html"))).To(BeFalse()) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,122 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	"github.com/prometheus/client_golang/prometheus" | ||||||
|  | 	"github.com/prometheus/client_golang/prometheus/promhttp" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // DefaultMetricsHandler is the default http.Handler for serving metrics from
 | ||||||
|  | // the default prometheus.Registry
 | ||||||
|  | var DefaultMetricsHandler = NewMetricsHandlerWithDefaultRegistry() | ||||||
|  | 
 | ||||||
|  | // NewMetricsHandlerWithDefaultRegistry creates a new http.Handler for serving
 | ||||||
|  | // metrics from the default prometheus.Registry.
 | ||||||
|  | func NewMetricsHandlerWithDefaultRegistry() http.Handler { | ||||||
|  | 	return NewMetricsHandler(prometheus.DefaultRegisterer, prometheus.DefaultGatherer) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewMetricsHandler creates a new http.Handler for serving metrics from the
 | ||||||
|  | // provided prometheus.Registerer and prometheus.Gatherer
 | ||||||
|  | func NewMetricsHandler(registerer prometheus.Registerer, gatherer prometheus.Gatherer) http.Handler { | ||||||
|  | 	return promhttp.InstrumentMetricHandler( | ||||||
|  | 		registerer, promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{}), | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewRequestMetricsWithDefaultRegistry returns a middleware that will record
 | ||||||
|  | // metrics for HTTP requests to the default prometheus.Registry
 | ||||||
|  | func NewRequestMetricsWithDefaultRegistry() alice.Constructor { | ||||||
|  | 	return NewRequestMetrics(prometheus.DefaultRegisterer) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewRequestMetrics returns a middleware that will record metrics for HTTP
 | ||||||
|  | // requests to the provided prometheus.Registerer
 | ||||||
|  | func NewRequestMetrics(registerer prometheus.Registerer) alice.Constructor { | ||||||
|  | 	return func(next http.Handler) http.Handler { | ||||||
|  | 		// Counter for all requests
 | ||||||
|  | 		// This is bucketed based on the response code we set
 | ||||||
|  | 		counterHandler := func(next http.Handler) http.Handler { | ||||||
|  | 			return promhttp.InstrumentHandlerCounter(registerRequestsCounter(registerer), next) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Gauge to all requests currently being handled
 | ||||||
|  | 		inFlightHandler := func(next http.Handler) http.Handler { | ||||||
|  | 			return promhttp.InstrumentHandlerInFlight(registerInflightRequestsGauge(registerer), next) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// The latency of all requests bucketed by HTTP method
 | ||||||
|  | 		durationHandler := func(next http.Handler) http.Handler { | ||||||
|  | 			return promhttp.InstrumentHandlerDuration(registerRequestsLatencyHistogram(registerer), next) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return alice.New(counterHandler, inFlightHandler, durationHandler).Then(next) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // registerRequestsCounter registers the 'oauth2_proxy_requests_total' metric
 | ||||||
|  | // This keeps a tally of all received requests bucket by their HTTP response
 | ||||||
|  | // status code
 | ||||||
|  | func registerRequestsCounter(registerer prometheus.Registerer) *prometheus.CounterVec { | ||||||
|  | 	counter := prometheus.NewCounterVec( | ||||||
|  | 		prometheus.CounterOpts{ | ||||||
|  | 			Name: "oauth2_proxy_requests_total", | ||||||
|  | 			Help: "Total number of requests by HTTP status code.", | ||||||
|  | 		}, | ||||||
|  | 		[]string{"code"}, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if err := registerer.Register(counter); err != nil { | ||||||
|  | 		if are, ok := err.(prometheus.AlreadyRegisteredError); ok { | ||||||
|  | 			counter = are.ExistingCollector.(*prometheus.CounterVec) | ||||||
|  | 		} else { | ||||||
|  | 			panic(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return counter | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // registerInflightRequestsGauge registers 'oauth2_proxy_requests_in_flight'
 | ||||||
|  | // This only keeps the count of currently in progress HTTP requests
 | ||||||
|  | func registerInflightRequestsGauge(registerer prometheus.Registerer) prometheus.Gauge { | ||||||
|  | 	gauge := prometheus.NewGauge(prometheus.GaugeOpts{ | ||||||
|  | 		Name: "oauth2_proxy_requests_in_flight", | ||||||
|  | 		Help: "Current number of requests being served.", | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if err := registerer.Register(gauge); err != nil { | ||||||
|  | 		if are, ok := err.(prometheus.AlreadyRegisteredError); ok { | ||||||
|  | 			gauge = are.ExistingCollector.(prometheus.Gauge) | ||||||
|  | 		} else { | ||||||
|  | 			panic(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return gauge | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // registerRequestsLatencyHistogram registers 'oauth2_proxy_response_duration_seconds'
 | ||||||
|  | // This keeps tally of the requests bucketed by the time taken to process the request
 | ||||||
|  | func registerRequestsLatencyHistogram(registerer prometheus.Registerer) *prometheus.HistogramVec { | ||||||
|  | 	histogram := prometheus.NewHistogramVec( | ||||||
|  | 		prometheus.HistogramOpts{ | ||||||
|  | 			Name:    "oauth2_proxy_response_duration_seconds", | ||||||
|  | 			Help:    "A histogram of request latencies.", | ||||||
|  | 			Buckets: prometheus.DefBuckets, | ||||||
|  | 		}, | ||||||
|  | 		[]string{"method"}, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if err := registerer.Register(histogram); err != nil { | ||||||
|  | 		if are, ok := err.(prometheus.AlreadyRegisteredError); ok { | ||||||
|  | 			histogram = are.ExistingCollector.(*prometheus.HistogramVec) | ||||||
|  | 		} else { | ||||||
|  | 			panic(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return histogram | ||||||
|  | } | ||||||
|  | @ -0,0 +1,67 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"os" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | 
 | ||||||
|  | 	"github.com/prometheus/client_golang/prometheus" | ||||||
|  | 	"github.com/prometheus/client_golang/prometheus/testutil" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Instrumentation suite", func() { | ||||||
|  | 	type requestTableInput struct { | ||||||
|  | 		registry        *prometheus.Registry | ||||||
|  | 		requestString   string | ||||||
|  | 		expectedHandler http.Handler | ||||||
|  | 		expectedMetrics []string | ||||||
|  | 		expectedStatus  int | ||||||
|  | 		// Prometheus output is large so is stored in testdata
 | ||||||
|  | 		expectedResultsFile string | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DescribeTable("when serving a request", | ||||||
|  | 		func(in *requestTableInput) { | ||||||
|  | 			req := httptest.NewRequest("", in.requestString, nil) | ||||||
|  | 
 | ||||||
|  | 			rw := httptest.NewRecorder() | ||||||
|  | 
 | ||||||
|  | 			handler := NewRequestMetrics(in.registry)(in.expectedHandler) | ||||||
|  | 			handler.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 			Expect(rw.Code).To(Equal(in.expectedStatus)) | ||||||
|  | 
 | ||||||
|  | 			expectedPrometheusText, err := os.Open(in.expectedResultsFile) | ||||||
|  | 			Expect(err).NotTo(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			err = testutil.GatherAndCompare(in.registry, expectedPrometheusText, in.expectedMetrics...) | ||||||
|  | 			Expect(err).NotTo(HaveOccurred()) | ||||||
|  | 		}, | ||||||
|  | 		Entry("successfully", func() *requestTableInput { | ||||||
|  | 			in := &requestTableInput{ | ||||||
|  | 				registry:      prometheus.NewRegistry(), | ||||||
|  | 				requestString: "http://example.com/metrics", | ||||||
|  | 				expectedMetrics: []string{ | ||||||
|  | 					"oauth2_proxy_requests_total", | ||||||
|  | 				}, | ||||||
|  | 				expectedStatus:      200, | ||||||
|  | 				expectedResultsFile: "testdata/metrics/successfulrequest.txt", | ||||||
|  | 			} | ||||||
|  | 			in.expectedHandler = NewMetricsHandler(in.registry, in.registry) | ||||||
|  | 
 | ||||||
|  | 			return in | ||||||
|  | 		}()), | ||||||
|  | 		Entry("with not found", &requestTableInput{ | ||||||
|  | 			registry:            prometheus.NewRegistry(), | ||||||
|  | 			requestString:       "http://example.com/", | ||||||
|  | 			expectedHandler:     http.NotFoundHandler(), | ||||||
|  | 			expectedMetrics:     []string{"oauth2_proxy_requests_total"}, | ||||||
|  | 			expectedStatus:      404, | ||||||
|  | 			expectedResultsFile: "testdata/metrics/notfoundrequest.txt", | ||||||
|  | 		}), | ||||||
|  | 	) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,3 @@ | ||||||
|  | # HELP oauth2_proxy_requests_total Total number of requests by HTTP status code. | ||||||
|  | # TYPE oauth2_proxy_requests_total counter | ||||||
|  | oauth2_proxy_requests_total{code="404"} 1 | ||||||
|  | @ -0,0 +1,3 @@ | ||||||
|  | # HELP oauth2_proxy_requests_total Total number of requests by HTTP status code. | ||||||
|  | # TYPE oauth2_proxy_requests_total counter | ||||||
|  | oauth2_proxy_requests_total{code="200"} 1 | ||||||
|  | @ -85,7 +85,7 @@ func (m *Manager) ReleaseLock(req *http.Request) error { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return tckt.releaseSession(func(key string) error { | 	return tckt.releaseSessionLock(func(key string) error { | ||||||
| 		return m.Store.ReleaseLock(req.Context(), key) | 		return m.Store.ReleaseLock(req.Context(), key) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -26,9 +26,9 @@ type saveFunc func(string, []byte, time.Duration) error | ||||||
| // string key and returning the stored value as []byte
 | // string key and returning the stored value as []byte
 | ||||||
| type loadFunc func(string) ([]byte, error) | type loadFunc func(string) ([]byte, error) | ||||||
| 
 | 
 | ||||||
| // lockFunc performs a lock on a persistent store using a
 | // releaseLockFunc performs a lock or releases a lock on a persistent store using a
 | ||||||
| // string key
 | // string key
 | ||||||
| type lockFunc func(string) error | type releaseLockFunc func(string) error | ||||||
| 
 | 
 | ||||||
| // clearFunc performs a persistent store's clear functionality using
 | // clearFunc performs a persistent store's clear functionality using
 | ||||||
| // a string key for the target of the deletion.
 | // a string key for the target of the deletion.
 | ||||||
|  | @ -139,9 +139,9 @@ func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) { | ||||||
| 	return sessions.DecodeSessionState(ciphertext, c, false) | 	return sessions.DecodeSessionState(ciphertext, c, false) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // releaseSession releases a potential locked session
 | // releaseSessionLock releases a potential locked session
 | ||||||
| func (t *ticket) releaseSession(loader lockFunc) error { | func (t *ticket) releaseSessionLock(releaseLock releaseLockFunc) error { | ||||||
| 	err := loader(t.id) | 	err := releaseLock(t.id) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to release session state with the ticket: %v", err) | 		return fmt.Errorf("failed to release session state with the ticket: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue