diff --git a/CHANGELOG.md b/CHANGELOG.md index 1df28499..a7a499f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.0.1 +- [#1028](https://github.com/oauth2-proxy/oauth2-proxy/pull/1028) Refactor templates, update theme and provide styled error pages (@JoelSpeed) - [#1039](https://github.com/oauth2-proxy/oauth2-proxy/pull/1039) Ensure errors in tests are logged to the GinkgoWriter (@JoelSpeed) # V7.0.1 diff --git a/oauthproxy.go b/oauthproxy.go index 0cfa1f93..0ef5060c 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -18,6 +18,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" @@ -116,7 +117,10 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("error initialising session store: %v", err) } - templates := loadTemplates(opts.CustomTemplatesDir) + templates, err := app.LoadTemplates(opts.Templates.Path) + if err != nil { + return nil, fmt.Errorf("error loading templates: %v", err) + } proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) if err != nil { @@ -211,12 +215,12 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr SkipProviderButton: opts.SkipProviderButton, templates: templates, trustedIPs: trustedIPs, - Banner: opts.Banner, - Footer: opts.Footer, + Banner: opts.Templates.Banner, + Footer: opts.Templates.Footer, SignInMessage: buildSignInMessage(opts), basicAuthValidator: basicAuthValidator, - displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, + displayHtpasswdForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, sessionChain: sessionChain, headersChain: headersChain, preAuthChain: preAuthChain, @@ -301,11 +305,11 @@ func buildHeadersChain(opts *options.Options) (alice.Chain, error) { func buildSignInMessage(opts *options.Options) string { var msg string - if len(opts.Banner) >= 1 { - if opts.Banner == "-" { + if len(opts.Templates.Banner) >= 1 { + if opts.Templates.Banner == "-" { msg = "" } else { - msg = opts.Banner + msg = opts.Templates.Banner } } else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { if len(opts.EmailDomains) > 1 { @@ -478,7 +482,7 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { switch path := req.URL.Path; { case path == p.RobotsPath: - p.RobotsTxt(rw) + p.RobotsTxt(rw, req) case p.IsAllowedRequest(req): p.SkipAuthProxy(rw, req) case path == p.SignInPath: @@ -499,30 +503,49 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { } // RobotsTxt disallows scraping pages from the OAuthProxy -func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { +func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter, req *http.Request) { _, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") if err != nil { logger.Printf("Error writing robots.txt: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } rw.WriteHeader(http.StatusOK) } // ErrorPage writes an error response -func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { +func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) { + redirectURL, err := p.getAppRedirect(req) + if err != nil { + logger.Errorf("Error obtaining redirect: %v", err) + } + if redirectURL == p.SignInPath || redirectURL == "" { + redirectURL = "/" + } + rw.WriteHeader(code) + + // We allow unescaped template.HTML since it is user configured options + /* #nosec G203 */ t := struct { Title string Message string ProxyPrefix string + StatusCode int + Redirect string + Footer template.HTML + Version string }{ - Title: fmt.Sprintf("%d %s", code, title), + Title: title, Message: message, ProxyPrefix: p.ProxyPrefix, + StatusCode: code, + Redirect: redirectURL, + Footer: template.HTML(p.Footer), + Version: VERSION, } - err := p.templates.ExecuteTemplate(rw, "error.html", t) - if err != nil { + + if err := p.templates.ExecuteTemplate(rw, "error.html", t); err != nil { logger.Printf("Error rendering error.html template: %v", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) } @@ -570,7 +593,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code err := p.ClearSessionCookie(rw, req) if err != nil { logger.Printf("Error clearing session cookie: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } rw.WriteHeader(code) @@ -578,7 +601,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code redirectURL, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } @@ -611,7 +634,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code err = p.templates.ExecuteTemplate(rw, "sign_in.html", t) if err != nil { logger.Printf("Error rendering sign_in.html template: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) } } @@ -639,7 +662,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } @@ -649,7 +672,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { err = p.SaveSession(rw, req, session) if err != nil { logger.Printf("Error saving session: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } http.Redirect(rw, req, redirect, http.StatusFound) @@ -688,7 +711,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { err = json.NewEncoder(rw).Encode(userInfo) if err != nil { logger.Printf("Error encoding user info: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) } } @@ -697,13 +720,13 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } err = p.ClearSessionCookie(rw, req) if err != nil { logger.Errorf("Error clearing session cookie: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } http.Redirect(rw, req, redirect, http.StatusFound) @@ -715,14 +738,14 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { nonce, err := encryption.Nonce() if err != nil { logger.Errorf("Error obtaining nonce: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } p.SetCSRFCookie(rw, req, nonce) redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } redirectURI := p.getOAuthRedirectURI(req) @@ -738,34 +761,34 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { err := req.ParseForm() if err != nil { logger.Errorf("Error while parsing OAuth2 callback: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } errorString := req.Form.Get("error") if errorString != "" { logger.Errorf("Error while parsing OAuth2 callback: %s", errorString) - p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", errorString) + p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", errorString) return } session, err := p.redeemCode(req) if err != nil { logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", "Internal Error") return } err = p.enrichSessionState(req.Context(), session) if err != nil { logger.Errorf("Error creating session during OAuth2 callback: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", "Internal Error") return } state := strings.SplitN(req.Form.Get("state"), ":", 2) if len(state) != 2 { logger.Error("Error while parsing OAuth2 state: invalid length") - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State") + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", "Invalid State") return } nonce := state[0] @@ -773,13 +796,13 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { c, err := req.Cookie(p.CSRFCookieName) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie") - p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", err.Error()) + p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", err.Error()) return } p.ClearCSRFCookie(rw, req) if c.Value != nonce { logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: CSRF token mismatch, potential attack") - p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed") + p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", "CSRF Failed") return } @@ -797,13 +820,13 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { err := p.SaveSession(rw, req, session) if err != nil { logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) - p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } http.Redirect(rw, req, redirect, http.StatusFound) } else { logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized") - p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "Invalid Account") + p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", "Invalid Account") } } @@ -885,12 +908,12 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { } case ErrAccessDenied: - p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized") + p.ErrorPage(rw, req, http.StatusUnauthorized, "Permission Denied", "Unauthorized") default: // unknown error logger.Errorf("Unexpected internal error: %v", err) - p.ErrorPage(rw, http.StatusInternalServerError, + p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", "Internal Error") } } diff --git a/pkg/apis/options/app.go b/pkg/apis/options/app.go new file mode 100644 index 00000000..1574ac97 --- /dev/null +++ b/pkg/apis/options/app.go @@ -0,0 +1,43 @@ +package options + +import "github.com/spf13/pflag" + +// Templates includes options for configuring the sign in and error pages +// appearance. +type Templates struct { + // Path is the path to a folder containing a sign_in.html and an error.html + // template. + // These files will be used instead of the default templates if present. + // If either file is missing, the default will be used instead. + Path string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` + + // Banner overides the default sign_in page banner text. If unspecified, + // the message will give users a list of allowed email domains. + Banner string `flag:"banner" cfg:"banner"` + + // Footer overrides the default sign_in page footer text. + Footer string `flag:"footer" cfg:"footer"` + + // DisplayLoginForm determines whether the sign_in page should render a + // password form if a static passwords file (htpasswd file) has been + // configured. + DisplayLoginForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` +} + +func templatesFlagSet() *pflag.FlagSet { + flagSet := pflag.NewFlagSet("templates", pflag.ExitOnError) + + flagSet.String("custom-templates-dir", "", "path to custom html templates") + flagSet.String("banner", "", "custom banner string. Use \"-\" to disable default banner.") + flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") + flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") + + return flagSet +} + +// templatesDefaults creates a Templates and populates it with any default values +func templatesDefaults() Templates { + return Templates{ + DisplayLoginForm: true, + } +} diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index c0f91422..9e8b6366 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -53,14 +53,11 @@ type Options struct { GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email"` GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json"` HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` - DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` - CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` - Banner string `flag:"banner" cfg:"banner"` - Footer string `flag:"footer" cfg:"footer"` - Cookie Cookie `cfg:",squash"` - Session SessionOptions `cfg:",squash"` - Logging Logging `cfg:",squash"` + Cookie Cookie `cfg:",squash"` + Session SessionOptions `cfg:",squash"` + Logging Logging `cfg:",squash"` + Templates Templates `cfg:",squash"` // Not used in the legacy config, name not allowed to match an external key (upstreams) // TODO(JoelSpeed): Rename when legacy config is removed @@ -135,16 +132,17 @@ func (o *Options) SetRealClientIPParser(s ipapi.RealClientIPParser) { o.realClie // NewOptions constructs a new Options with defaulted values func NewOptions() *Options { return &Options{ - ProxyPrefix: "/oauth2", - ProviderType: "google", - PingPath: "/ping", - HTTPAddress: "127.0.0.1:4180", - HTTPSAddress: ":443", - RealClientIPHeader: "X-Real-IP", - ForceHTTPS: false, - DisplayHtpasswdForm: true, + ProxyPrefix: "/oauth2", + ProviderType: "google", + PingPath: "/ping", + HTTPAddress: "127.0.0.1:4180", + HTTPSAddress: ":443", + RealClientIPHeader: "X-Real-IP", + ForceHTTPS: false, + Cookie: cookieDefaults(), Session: sessionOptionsDefaults(), + Templates: templatesDefaults(), AzureTenant: "common", SkipAuthPreflight: false, Prompt: "", // Change to "login" when ApprovalPrompt officially deprecated @@ -200,10 +198,6 @@ func NewFlagSet() *pflag.FlagSet { flagSet.String("client-secret-file", "", "the file with OAuth Client Secret") flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") flagSet.String("htpasswd-file", "", "additionally authenticate against a htpasswd file. Entries must be created with \"htpasswd -B\" for bcrypt encryption") - flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") - flagSet.String("custom-templates-dir", "", "path to custom html templates") - flagSet.String("banner", "", "custom banner string. Use \"-\" to disable default banner.") - flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. //sign_in)") flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") @@ -251,6 +245,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet()) + flagSet.AddFlagSet(templatesFlagSet()) return flagSet } diff --git a/pkg/app/app_suite_test.go b/pkg/app/app_suite_test.go new file mode 100644 index 00000000..d2df0233 --- /dev/null +++ b/pkg/app/app_suite_test.go @@ -0,0 +1,17 @@ +package app + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestOptionsSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + logger.SetErrOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "App Suite") +} diff --git a/pkg/app/templates.go b/pkg/app/templates.go new file mode 100644 index 00000000..ef38c902 --- /dev/null +++ b/pkg/app/templates.go @@ -0,0 +1,257 @@ +package app + +import ( + "fmt" + "html/template" + "os" + "path/filepath" + "strings" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" +) + +const ( + errorTemplateName = "error.html" + signInTemplateName = "sign_in.html" + + defaultErrorTemplate = `{{define "error.html"}} + + + + + + {{.StatusCode}} {{.Title}} + + + + + + + + +
+
+
{{.StatusCode}}
+
+

{{.Title}}

+
+ + {{ if .Message }} +
+
+

More Info

+ + + +
+ +
+ {{ end }} + +
+ +
+
+
+ +
+
+
+
+ + +
+
+
+ +
+
+ + + + + +{{end}}` + + defaultSignInTemplate = `{{define "sign_in.html"}} + + + + + + Sign In + + + + + + + +
+ +
+ + + + + +{{end}}` +) + +// LoadTemplates adds the Sign In and Error templates from the custom template +// directory, or uses the defaults if they do not exist or the custom directory +// is not provided. +func LoadTemplates(customDir string) (*template.Template, error) { + t := template.New("").Funcs(template.FuncMap{ + "ToUpper": strings.ToUpper, + "ToLower": strings.ToLower, + }) + var err error + t, err = addTemplate(t, customDir, signInTemplateName, defaultSignInTemplate) + if err != nil { + return nil, fmt.Errorf("could not add Sign In template: %v", err) + } + t, err = addTemplate(t, customDir, errorTemplateName, defaultErrorTemplate) + if err != nil { + return nil, fmt.Errorf("could not add Error template: %v", err) + } + + return t, nil +} + +// addTemplate will add the template from the custom directory if provided, +// else it will add the default template. +func addTemplate(t *template.Template, customDir, fileName, defaultTemplate string) (*template.Template, error) { + filePath := filepath.Join(customDir, fileName) + if customDir != "" && isFile(filePath) { + t, err := t.ParseFiles(filePath) + if err != nil { + return nil, fmt.Errorf("failed to parse template %s: %v", filePath, err) + } + return t, nil + } + t, err := t.Parse(defaultTemplate) + if err != nil { + // This should not happen. + // Default templates should be tested and so should never fail to parse. + logger.Panic("Could not parse defaultTemplate: ", err) + } + return t, nil +} + +// isFile checks if the file exists and checks whether it is a regular file. +// If either of these fail then it cannot be used as a template file. +func isFile(fileName string) bool { + info, err := os.Stat(fileName) + if err != nil { + logger.Errorf("Could not load file %s: %v, will use default template", fileName, err) + return false + } + return info.Mode().IsRegular() +} diff --git a/pkg/app/templates_test.go b/pkg/app/templates_test.go new file mode 100644 index 00000000..66f38b7f --- /dev/null +++ b/pkg/app/templates_test.go @@ -0,0 +1,199 @@ +package app + +import ( + "bytes" + "html/template" + "io/ioutil" + "os" + "path/filepath" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Templates", func() { + var customDir string + + BeforeEach(func() { + var err error + customDir, err = ioutil.TempDir("", "oauth2-proxy-templates-test") + Expect(err).ToNot(HaveOccurred()) + + templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` + signInFile := filepath.Join(customDir, signInTemplateName) + Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0666)).To(Succeed()) + errorFile := filepath.Join(customDir, errorTemplateName) + Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0666)).To(Succeed()) + }) + + AfterEach(func() { + Expect(os.RemoveAll(customDir)).To(Succeed()) + }) + + Context("LoadTemplates", func() { + var data interface{} + var t *template.Template + + BeforeEach(func() { + data = struct { + // For default templates + ProxyPrefix string + Redirect string + Footer string + + // For default sign_in template + SignInMessage string + ProviderName string + CustomLogin bool + + // For default error template + StatusCode int + Title string + Message string + + // For custom templates + TestString string + }{ + ProxyPrefix: "", + Redirect: "", + Footer: "