Merge pull request #1028 from oauth2-proxy/templates
Refactor templates, update theme and provide styled error pages
This commit is contained in:
		
						commit
						9cdcd2b2d4
					
				|  | @ -8,6 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v7.0.1 | ## Changes since v7.0.1 | ||||||
| 
 | 
 | ||||||
|  | - [#1028](https://github.com/oauth2-proxy/oauth2-proxy/pull/1028) Refactor templates, update theme and provide styled error pages (@JoelSpeed) | ||||||
| - [#1039](https://github.com/oauth2-proxy/oauth2-proxy/pull/1039) Ensure errors in tests are logged to the GinkgoWriter (@JoelSpeed) | - [#1039](https://github.com/oauth2-proxy/oauth2-proxy/pull/1039) Ensure errors in tests are logged to the GinkgoWriter (@JoelSpeed) | ||||||
| 
 | 
 | ||||||
| # V7.0.1 | # V7.0.1 | ||||||
|  |  | ||||||
|  | @ -18,6 +18,7 @@ import ( | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||||
|  | @ -116,7 +117,10 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		return nil, fmt.Errorf("error initialising session store: %v", err) | 		return nil, fmt.Errorf("error initialising session store: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	templates := loadTemplates(opts.CustomTemplatesDir) | 	templates, err := app.LoadTemplates(opts.Templates.Path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error loading templates: %v", err) | ||||||
|  | 	} | ||||||
| 	proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) | 	proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) | ||||||
| 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) | 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -211,12 +215,12 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		SkipProviderButton:   opts.SkipProviderButton, | 		SkipProviderButton:   opts.SkipProviderButton, | ||||||
| 		templates:            templates, | 		templates:            templates, | ||||||
| 		trustedIPs:           trustedIPs, | 		trustedIPs:           trustedIPs, | ||||||
| 		Banner:               opts.Banner, | 		Banner:               opts.Templates.Banner, | ||||||
| 		Footer:               opts.Footer, | 		Footer:               opts.Templates.Footer, | ||||||
| 		SignInMessage:        buildSignInMessage(opts), | 		SignInMessage:        buildSignInMessage(opts), | ||||||
| 
 | 
 | ||||||
| 		basicAuthValidator:  basicAuthValidator, | 		basicAuthValidator:  basicAuthValidator, | ||||||
| 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | 		displayHtpasswdForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, | ||||||
| 		sessionChain:        sessionChain, | 		sessionChain:        sessionChain, | ||||||
| 		headersChain:        headersChain, | 		headersChain:        headersChain, | ||||||
| 		preAuthChain:        preAuthChain, | 		preAuthChain:        preAuthChain, | ||||||
|  | @ -301,11 +305,11 @@ func buildHeadersChain(opts *options.Options) (alice.Chain, error) { | ||||||
| 
 | 
 | ||||||
| func buildSignInMessage(opts *options.Options) string { | func buildSignInMessage(opts *options.Options) string { | ||||||
| 	var msg string | 	var msg string | ||||||
| 	if len(opts.Banner) >= 1 { | 	if len(opts.Templates.Banner) >= 1 { | ||||||
| 		if opts.Banner == "-" { | 		if opts.Templates.Banner == "-" { | ||||||
| 			msg = "" | 			msg = "" | ||||||
| 		} else { | 		} else { | ||||||
| 			msg = opts.Banner | 			msg = opts.Templates.Banner | ||||||
| 		} | 		} | ||||||
| 	} else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { | 	} else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { | ||||||
| 		if len(opts.EmailDomains) > 1 { | 		if len(opts.EmailDomains) > 1 { | ||||||
|  | @ -478,7 +482,7 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 
 | 
 | ||||||
| 	switch path := req.URL.Path; { | 	switch path := req.URL.Path; { | ||||||
| 	case path == p.RobotsPath: | 	case path == p.RobotsPath: | ||||||
| 		p.RobotsTxt(rw) | 		p.RobotsTxt(rw, req) | ||||||
| 	case p.IsAllowedRequest(req): | 	case p.IsAllowedRequest(req): | ||||||
| 		p.SkipAuthProxy(rw, req) | 		p.SkipAuthProxy(rw, req) | ||||||
| 	case path == p.SignInPath: | 	case path == p.SignInPath: | ||||||
|  | @ -499,30 +503,49 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RobotsTxt disallows scraping pages from the OAuthProxy
 | // RobotsTxt disallows scraping pages from the OAuthProxy
 | ||||||
| func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { | func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | 	_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error writing robots.txt: %v", err) | 		logger.Printf("Error writing robots.txt: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	rw.WriteHeader(http.StatusOK) | 	rw.WriteHeader(http.StatusOK) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ErrorPage writes an error response
 | // ErrorPage writes an error response
 | ||||||
| func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { | func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) { | ||||||
|  | 	redirectURL, err := p.getAppRedirect(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if redirectURL == p.SignInPath || redirectURL == "" { | ||||||
|  | 		redirectURL = "/" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
|  | 
 | ||||||
|  | 	// We allow unescaped template.HTML since it is user configured options
 | ||||||
|  | 	/* #nosec G203 */ | ||||||
| 	t := struct { | 	t := struct { | ||||||
| 		Title       string | 		Title       string | ||||||
| 		Message     string | 		Message     string | ||||||
| 		ProxyPrefix string | 		ProxyPrefix string | ||||||
|  | 		StatusCode  int | ||||||
|  | 		Redirect    string | ||||||
|  | 		Footer      template.HTML | ||||||
|  | 		Version     string | ||||||
| 	}{ | 	}{ | ||||||
| 		Title:       fmt.Sprintf("%d %s", code, title), | 		Title:       title, | ||||||
| 		Message:     message, | 		Message:     message, | ||||||
| 		ProxyPrefix: p.ProxyPrefix, | 		ProxyPrefix: p.ProxyPrefix, | ||||||
|  | 		StatusCode:  code, | ||||||
|  | 		Redirect:    redirectURL, | ||||||
|  | 		Footer:      template.HTML(p.Footer), | ||||||
|  | 		Version:     VERSION, | ||||||
| 	} | 	} | ||||||
| 	err := p.templates.ExecuteTemplate(rw, "error.html", t) | 
 | ||||||
| 	if err != nil { | 	if err := p.templates.ExecuteTemplate(rw, "error.html", t); err != nil { | ||||||
| 		logger.Printf("Error rendering error.html template: %v", err) | 		logger.Printf("Error rendering error.html template: %v", err) | ||||||
| 		http.Error(rw, "Internal Server Error", http.StatusInternalServerError) | 		http.Error(rw, "Internal Server Error", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|  | @ -570,7 +593,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 	err := p.ClearSessionCookie(rw, req) | 	err := p.ClearSessionCookie(rw, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error clearing session cookie: %v", err) | 		logger.Printf("Error clearing session cookie: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
|  | @ -578,7 +601,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 	redirectURL, err := p.getAppRedirect(req) | 	redirectURL, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -611,7 +634,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 	err = p.templates.ExecuteTemplate(rw, "sign_in.html", t) | 	err = p.templates.ExecuteTemplate(rw, "sign_in.html", t) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error rendering sign_in.html template: %v", err) | 		logger.Printf("Error rendering sign_in.html template: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -639,7 +662,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	redirect, err := p.getAppRedirect(req) | 	redirect, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -649,7 +672,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		err = p.SaveSession(rw, req, session) | 		err = p.SaveSession(rw, req, session) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("Error saving session: %v", err) | 			logger.Printf("Error saving session: %v", err) | ||||||
| 			p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 			p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		http.Redirect(rw, req, redirect, http.StatusFound) | 		http.Redirect(rw, req, redirect, http.StatusFound) | ||||||
|  | @ -688,7 +711,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	err = json.NewEncoder(rw).Encode(userInfo) | 	err = json.NewEncoder(rw).Encode(userInfo) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error encoding user info: %v", err) | 		logger.Printf("Error encoding user info: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -697,13 +720,13 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	redirect, err := p.getAppRedirect(req) | 	redirect, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	err = p.ClearSessionCookie(rw, req) | 	err = p.ClearSessionCookie(rw, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error clearing session cookie: %v", err) | 		logger.Errorf("Error clearing session cookie: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	http.Redirect(rw, req, redirect, http.StatusFound) | 	http.Redirect(rw, req, redirect, http.StatusFound) | ||||||
|  | @ -715,14 +738,14 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	nonce, err := encryption.Nonce() | 	nonce, err := encryption.Nonce() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining nonce: %v", err) | 		logger.Errorf("Error obtaining nonce: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	p.SetCSRFCookie(rw, req, nonce) | 	p.SetCSRFCookie(rw, req, nonce) | ||||||
| 	redirect, err := p.getAppRedirect(req) | 	redirect, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	redirectURI := p.getOAuthRedirectURI(req) | 	redirectURI := p.getOAuthRedirectURI(req) | ||||||
|  | @ -738,34 +761,34 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	err := req.ParseForm() | 	err := req.ParseForm() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error while parsing OAuth2 callback: %v", err) | 		logger.Errorf("Error while parsing OAuth2 callback: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	errorString := req.Form.Get("error") | 	errorString := req.Form.Get("error") | ||||||
| 	if errorString != "" { | 	if errorString != "" { | ||||||
| 		logger.Errorf("Error while parsing OAuth2 callback: %s", errorString) | 		logger.Errorf("Error while parsing OAuth2 callback: %s", errorString) | ||||||
| 		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", errorString) | 		p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", errorString) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	session, err := p.redeemCode(req) | 	session, err := p.redeemCode(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) | 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = p.enrichSessionState(req.Context(), session) | 	err = p.enrichSessionState(req.Context(), session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error creating session during OAuth2 callback: %v", err) | 		logger.Errorf("Error creating session during OAuth2 callback: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	state := strings.SplitN(req.Form.Get("state"), ":", 2) | 	state := strings.SplitN(req.Form.Get("state"), ":", 2) | ||||||
| 	if len(state) != 2 { | 	if len(state) != 2 { | ||||||
| 		logger.Error("Error while parsing OAuth2 state: invalid length") | 		logger.Error("Error while parsing OAuth2 state: invalid length") | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State") | 		p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", "Invalid State") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	nonce := state[0] | 	nonce := state[0] | ||||||
|  | @ -773,13 +796,13 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	c, err := req.Cookie(p.CSRFCookieName) | 	c, err := req.Cookie(p.CSRFCookieName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie") | 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie") | ||||||
| 		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", err.Error()) | 		p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	p.ClearCSRFCookie(rw, req) | 	p.ClearCSRFCookie(rw, req) | ||||||
| 	if c.Value != nonce { | 	if c.Value != nonce { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: CSRF token mismatch, potential attack") | 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: CSRF token mismatch, potential attack") | ||||||
| 		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed") | 		p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", "CSRF Failed") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -797,13 +820,13 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		err := p.SaveSession(rw, req, session) | 		err := p.SaveSession(rw, req, session) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) | 			logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) | ||||||
| 			p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 			p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		http.Redirect(rw, req, redirect, http.StatusFound) | 		http.Redirect(rw, req, redirect, http.StatusFound) | ||||||
| 	} else { | 	} else { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized") | 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized") | ||||||
| 		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "Invalid Account") | 		p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", "Invalid Account") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -885,12 +908,12 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 	case ErrAccessDenied: | 	case ErrAccessDenied: | ||||||
| 		p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized") | 		p.ErrorPage(rw, req, http.StatusUnauthorized, "Permission Denied", "Unauthorized") | ||||||
| 
 | 
 | ||||||
| 	default: | 	default: | ||||||
| 		// unknown error
 | 		// unknown error
 | ||||||
| 		logger.Errorf("Unexpected internal error: %v", err) | 		logger.Errorf("Unexpected internal error: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, | 		p.ErrorPage(rw, req, http.StatusInternalServerError, | ||||||
| 			"Internal Error", "Internal Error") | 			"Internal Error", "Internal Error") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,43 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | import "github.com/spf13/pflag" | ||||||
|  | 
 | ||||||
|  | // Templates includes options for configuring the sign in and error pages
 | ||||||
|  | // appearance.
 | ||||||
|  | type Templates struct { | ||||||
|  | 	// Path is the path to a folder containing a sign_in.html and an error.html
 | ||||||
|  | 	// template.
 | ||||||
|  | 	// These files will be used instead of the default templates if present.
 | ||||||
|  | 	// If either file is missing, the default will be used instead.
 | ||||||
|  | 	Path string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` | ||||||
|  | 
 | ||||||
|  | 	// Banner overides the default sign_in page banner text. If unspecified,
 | ||||||
|  | 	// the message will give users a list of allowed email domains.
 | ||||||
|  | 	Banner string `flag:"banner" cfg:"banner"` | ||||||
|  | 
 | ||||||
|  | 	// Footer overrides the default sign_in page footer text.
 | ||||||
|  | 	Footer string `flag:"footer" cfg:"footer"` | ||||||
|  | 
 | ||||||
|  | 	// DisplayLoginForm determines whether the sign_in page should render a
 | ||||||
|  | 	// password form if a static passwords file (htpasswd file) has been
 | ||||||
|  | 	// configured.
 | ||||||
|  | 	DisplayLoginForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func templatesFlagSet() *pflag.FlagSet { | ||||||
|  | 	flagSet := pflag.NewFlagSet("templates", pflag.ExitOnError) | ||||||
|  | 
 | ||||||
|  | 	flagSet.String("custom-templates-dir", "", "path to custom html templates") | ||||||
|  | 	flagSet.String("banner", "", "custom banner string. Use \"-\" to disable default banner.") | ||||||
|  | 	flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") | ||||||
|  | 	flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") | ||||||
|  | 
 | ||||||
|  | 	return flagSet | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // templatesDefaults creates a Templates and populates it with any default values
 | ||||||
|  | func templatesDefaults() Templates { | ||||||
|  | 	return Templates{ | ||||||
|  | 		DisplayLoginForm: true, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -53,14 +53,11 @@ type Options struct { | ||||||
| 	GoogleAdminEmail         string   `flag:"google-admin-email" cfg:"google_admin_email"` | 	GoogleAdminEmail         string   `flag:"google-admin-email" cfg:"google_admin_email"` | ||||||
| 	GoogleServiceAccountJSON string   `flag:"google-service-account-json" cfg:"google_service_account_json"` | 	GoogleServiceAccountJSON string   `flag:"google-service-account-json" cfg:"google_service_account_json"` | ||||||
| 	HtpasswdFile             string   `flag:"htpasswd-file" cfg:"htpasswd_file"` | 	HtpasswdFile             string   `flag:"htpasswd-file" cfg:"htpasswd_file"` | ||||||
| 	DisplayHtpasswdForm      bool     `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` |  | ||||||
| 	CustomTemplatesDir       string   `flag:"custom-templates-dir" cfg:"custom_templates_dir"` |  | ||||||
| 	Banner                   string   `flag:"banner" cfg:"banner"` |  | ||||||
| 	Footer                   string   `flag:"footer" cfg:"footer"` |  | ||||||
| 
 | 
 | ||||||
| 	Cookie    Cookie         `cfg:",squash"` | 	Cookie    Cookie         `cfg:",squash"` | ||||||
| 	Session   SessionOptions `cfg:",squash"` | 	Session   SessionOptions `cfg:",squash"` | ||||||
| 	Logging   Logging        `cfg:",squash"` | 	Logging   Logging        `cfg:",squash"` | ||||||
|  | 	Templates Templates      `cfg:",squash"` | ||||||
| 
 | 
 | ||||||
| 	// Not used in the legacy config, name not allowed to match an external key (upstreams)
 | 	// Not used in the legacy config, name not allowed to match an external key (upstreams)
 | ||||||
| 	// TODO(JoelSpeed): Rename when legacy config is removed
 | 	// TODO(JoelSpeed): Rename when legacy config is removed
 | ||||||
|  | @ -142,9 +139,10 @@ func NewOptions() *Options { | ||||||
| 		HTTPSAddress:       ":443", | 		HTTPSAddress:       ":443", | ||||||
| 		RealClientIPHeader: "X-Real-IP", | 		RealClientIPHeader: "X-Real-IP", | ||||||
| 		ForceHTTPS:         false, | 		ForceHTTPS:         false, | ||||||
| 		DisplayHtpasswdForm:              true, | 
 | ||||||
| 		Cookie:                           cookieDefaults(), | 		Cookie:                           cookieDefaults(), | ||||||
| 		Session:                          sessionOptionsDefaults(), | 		Session:                          sessionOptionsDefaults(), | ||||||
|  | 		Templates:                        templatesDefaults(), | ||||||
| 		AzureTenant:                      "common", | 		AzureTenant:                      "common", | ||||||
| 		SkipAuthPreflight:                false, | 		SkipAuthPreflight:                false, | ||||||
| 		Prompt:                           "", // Change to "login" when ApprovalPrompt officially deprecated
 | 		Prompt:                           "", // Change to "login" when ApprovalPrompt officially deprecated
 | ||||||
|  | @ -200,10 +198,6 @@ func NewFlagSet() *pflag.FlagSet { | ||||||
| 	flagSet.String("client-secret-file", "", "the file with OAuth Client Secret") | 	flagSet.String("client-secret-file", "", "the file with OAuth Client Secret") | ||||||
| 	flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") | 	flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") | ||||||
| 	flagSet.String("htpasswd-file", "", "additionally authenticate against a htpasswd file. Entries must be created with \"htpasswd -B\" for bcrypt encryption") | 	flagSet.String("htpasswd-file", "", "additionally authenticate against a htpasswd file. Entries must be created with \"htpasswd -B\" for bcrypt encryption") | ||||||
| 	flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") |  | ||||||
| 	flagSet.String("custom-templates-dir", "", "path to custom html templates") |  | ||||||
| 	flagSet.String("banner", "", "custom banner string. Use \"-\" to disable default banner.") |  | ||||||
| 	flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") |  | ||||||
| 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | ||||||
| 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | ||||||
| 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | ||||||
|  | @ -251,6 +245,7 @@ func NewFlagSet() *pflag.FlagSet { | ||||||
| 
 | 
 | ||||||
| 	flagSet.AddFlagSet(cookieFlagSet()) | 	flagSet.AddFlagSet(cookieFlagSet()) | ||||||
| 	flagSet.AddFlagSet(loggingFlagSet()) | 	flagSet.AddFlagSet(loggingFlagSet()) | ||||||
|  | 	flagSet.AddFlagSet(templatesFlagSet()) | ||||||
| 
 | 
 | ||||||
| 	return flagSet | 	return flagSet | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,17 @@ | ||||||
|  | package app | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestOptionsSuite(t *testing.T) { | ||||||
|  | 	logger.SetOutput(GinkgoWriter) | ||||||
|  | 	logger.SetErrOutput(GinkgoWriter) | ||||||
|  | 
 | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "App Suite") | ||||||
|  | } | ||||||
|  | @ -0,0 +1,257 @@ | ||||||
|  | package app | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"html/template" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	errorTemplateName  = "error.html" | ||||||
|  | 	signInTemplateName = "sign_in.html" | ||||||
|  | 
 | ||||||
|  | 	defaultErrorTemplate = `{{define "error.html"}} | ||||||
|  | <!DOCTYPE html> | ||||||
|  | <html lang="en" charset="utf-8"> | ||||||
|  | <head> | ||||||
|  |   <meta charset="utf-8"> | ||||||
|  |   <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> | ||||||
|  |   	<title>{{.StatusCode}} {{.Title}}</title> | ||||||
|  |   <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css"> | ||||||
|  |   <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.2/css/all.min.css"> | ||||||
|  | 
 | ||||||
|  |   <script type="text/javascript"> | ||||||
|  |     document.addEventListener('DOMContentLoaded', function() { | ||||||
|  |     	let cardToggles = document.getElementsByClassName('card-toggle'); | ||||||
|  |     	for (let i = 0; i < cardToggles.length; i++) { | ||||||
|  |     		cardToggles[i].addEventListener('click', e => { | ||||||
|  |     			e.currentTarget.parentElement.parentElement.childNodes[3].classList.toggle('is-hidden'); | ||||||
|  |     		}); | ||||||
|  |     	} | ||||||
|  |     }); | ||||||
|  |   </script> | ||||||
|  | 
 | ||||||
|  |   <style> | ||||||
|  |     body { | ||||||
|  |       height: 100vh; | ||||||
|  |     } | ||||||
|  |     .error-box { | ||||||
|  |       margin: 1.25rem auto; | ||||||
|  |       max-width: 600px; | ||||||
|  |     } | ||||||
|  |     .status-code { | ||||||
|  |       font-size: 12rem; | ||||||
|  |       font-weight: 600; | ||||||
|  |     } | ||||||
|  |     #more-info.card { | ||||||
|  |       border: 1px solid #f0f0f0; | ||||||
|  |     } | ||||||
|  |     footer a { | ||||||
|  |       text-decoration: underline; | ||||||
|  |     } | ||||||
|  |   </style> | ||||||
|  | </head> | ||||||
|  | <body class="has-background-light"> | ||||||
|  |   <section class="section"> | ||||||
|  |     <div class="box block error-box has-text-centered"> | ||||||
|  |       <div class="status-code">{{.StatusCode}}</div> | ||||||
|  |       <div class="block"> | ||||||
|  |         <h1 class="subtitle is-1">{{.Title}}</h1> | ||||||
|  |       </div> | ||||||
|  | 
 | ||||||
|  |       {{ if .Message }} | ||||||
|  |       <div id="more-info" class="block card is-fullwidth is-shadowless"> | ||||||
|  |   			<header class="card-header is-shadowless"> | ||||||
|  |   				<p class="card-header-title">More Info</p> | ||||||
|  |   				<a class="card-header-icon card-toggle"> | ||||||
|  |   					<i class="fa fa-angle-down"></i> | ||||||
|  |   				</a> | ||||||
|  |   			</header> | ||||||
|  |   			<div class="card-content has-text-left is-hidden"> | ||||||
|  |   				<div class="content"> | ||||||
|  |   					{{.Message}} | ||||||
|  |   				</div> | ||||||
|  |   			</div> | ||||||
|  |   		</div> | ||||||
|  |       {{ end }} | ||||||
|  | 
 | ||||||
|  |       <hr> | ||||||
|  | 
 | ||||||
|  |       <div class="columns"> | ||||||
|  |         <div class="column"> | ||||||
|  |           <form method="GET" action="{{.Redirect}}"> | ||||||
|  |             <button type="submit" class="button is-danger is-fullwidth">Go back</button> | ||||||
|  |           </form> | ||||||
|  |         </div> | ||||||
|  |         <div class="column"> | ||||||
|  |           <form method="GET" action="{{.ProxyPrefix}}/sign_in"> | ||||||
|  |             <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  |             <button type="submit" class="button is-primary is-fullwidth">Sign in</button> | ||||||
|  |           </form> | ||||||
|  |         </div> | ||||||
|  |       </div> | ||||||
|  | 
 | ||||||
|  |     </div> | ||||||
|  |   </section> | ||||||
|  | 
 | ||||||
|  |   <footer class="footer has-text-grey has-background-light is-size-7"> | ||||||
|  |     <div class="content has-text-centered"> | ||||||
|  |     	{{ if eq .Footer "-" }} | ||||||
|  |     	{{ else if eq .Footer ""}} | ||||||
|  |     	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p> | ||||||
|  |     	{{ else }} | ||||||
|  |     	<p>{{.Footer}}</p> | ||||||
|  |     	{{ end }} | ||||||
|  |     </div> | ||||||
|  | 	</footer> | ||||||
|  | 
 | ||||||
|  |   </body> | ||||||
|  | </html> | ||||||
|  | {{end}}` | ||||||
|  | 
 | ||||||
|  | 	defaultSignInTemplate = `{{define "sign_in.html"}} | ||||||
|  | <!DOCTYPE html> | ||||||
|  | <html lang="en" charset="utf-8"> | ||||||
|  |   <head> | ||||||
|  |     <meta charset="utf-8"> | ||||||
|  |     <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> | ||||||
|  |     <title>Sign In</title> | ||||||
|  |     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.1/css/bulma.min.css"> | ||||||
|  | 
 | ||||||
|  |     <style> | ||||||
|  |       body { | ||||||
|  |         height: 100vh; | ||||||
|  |       } | ||||||
|  |       .sign-in-box { | ||||||
|  |         max-width: 400px; | ||||||
|  |         margin: 1.25rem auto; | ||||||
|  |       } | ||||||
|  |       footer a { | ||||||
|  |         text-decoration: underline; | ||||||
|  |       } | ||||||
|  |     </style> | ||||||
|  | 
 | ||||||
|  |     <script> | ||||||
|  |       if (window.location.hash) { | ||||||
|  |         (function() { | ||||||
|  |           var inputs = document.getElementsByName('rd'); | ||||||
|  |           for (var i = 0; i < inputs.length; i++) { | ||||||
|  |             // Add hash, but make sure it is only added once
 | ||||||
|  |             var idx = inputs[i].value.indexOf('#'); | ||||||
|  |             if (idx >= 0) { | ||||||
|  |               // Remove existing hash from URL
 | ||||||
|  |               inputs[i].value = inputs[i].value.substr(0, idx); | ||||||
|  |             } | ||||||
|  |             inputs[i].value += window.location.hash; | ||||||
|  |           } | ||||||
|  |         })(); | ||||||
|  |       } | ||||||
|  |     </script> | ||||||
|  |   </head> | ||||||
|  |   <body class="has-background-light"> | ||||||
|  |   <section class="section"> | ||||||
|  |     <div class="box block sign-in-box has-text-centered"> | ||||||
|  |       <form method="GET" action="{{.ProxyPrefix}}/start"> | ||||||
|  |         <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  |           {{ if .SignInMessage }} | ||||||
|  |           <p class="block">{{.SignInMessage}}</p> | ||||||
|  |           {{ end}} | ||||||
|  |           <button type="submit" class="button block is-primary">Sign in with {{.ProviderName}}</button> | ||||||
|  |       </form> | ||||||
|  | 
 | ||||||
|  |       {{ if .CustomLogin }} | ||||||
|  |       <hr> | ||||||
|  | 
 | ||||||
|  |       <form method="POST" action="{{.ProxyPrefix}}/sign_in" class="block"> | ||||||
|  |         <input type="hidden" name="rd" value="{{.Redirect}}"> | ||||||
|  | 
 | ||||||
|  |         <div class="field"> | ||||||
|  |           <label class="label" for="username">Username</label> | ||||||
|  |           <div class="control"> | ||||||
|  |             <input class="input" type="email" placeholder="e.g. userx@example.com"  name="username" id="username"> | ||||||
|  |           </div> | ||||||
|  |         </div> | ||||||
|  | 
 | ||||||
|  |         <div class="field"> | ||||||
|  |           <label class="label" for="password">Password</label> | ||||||
|  |           <div class="control"> | ||||||
|  |             <input class="input" type="password" placeholder="********" name="password" id="password"> | ||||||
|  |           </div> | ||||||
|  |         </div> | ||||||
|  |         <button class="button is-primary">Sign in</button> | ||||||
|  |         {{ end }} | ||||||
|  |     </form> | ||||||
|  |     </div> | ||||||
|  |   </section> | ||||||
|  | 
 | ||||||
|  |   <footer class="footer has-text-grey has-background-light is-size-7"> | ||||||
|  |     <div class="content has-text-centered"> | ||||||
|  |     	{{ if eq .Footer "-" }} | ||||||
|  |     	{{ else if eq .Footer ""}} | ||||||
|  |     	<p>Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy" class="has-text-grey">OAuth2 Proxy</a> version {{.Version}}</p> | ||||||
|  |     	{{ else }} | ||||||
|  |     	<p>{{.Footer}}</p> | ||||||
|  |     	{{ end }} | ||||||
|  |     </div> | ||||||
|  | 	</footer> | ||||||
|  | 
 | ||||||
|  |   </body> | ||||||
|  | </html> | ||||||
|  | {{end}}` | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // LoadTemplates adds the Sign In and Error templates from the custom template
 | ||||||
|  | // directory, or uses the defaults if they do not exist or the custom directory
 | ||||||
|  | // is not provided.
 | ||||||
|  | func LoadTemplates(customDir string) (*template.Template, error) { | ||||||
|  | 	t := template.New("").Funcs(template.FuncMap{ | ||||||
|  | 		"ToUpper": strings.ToUpper, | ||||||
|  | 		"ToLower": strings.ToLower, | ||||||
|  | 	}) | ||||||
|  | 	var err error | ||||||
|  | 	t, err = addTemplate(t, customDir, signInTemplateName, defaultSignInTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not add Sign In template: %v", err) | ||||||
|  | 	} | ||||||
|  | 	t, err = addTemplate(t, customDir, errorTemplateName, defaultErrorTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not add Error template: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // addTemplate will add the template from the custom directory if provided,
 | ||||||
|  | // else it will add the default template.
 | ||||||
|  | func addTemplate(t *template.Template, customDir, fileName, defaultTemplate string) (*template.Template, error) { | ||||||
|  | 	filePath := filepath.Join(customDir, fileName) | ||||||
|  | 	if customDir != "" && isFile(filePath) { | ||||||
|  | 		t, err := t.ParseFiles(filePath) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("failed to parse template %s: %v", filePath, err) | ||||||
|  | 		} | ||||||
|  | 		return t, nil | ||||||
|  | 	} | ||||||
|  | 	t, err := t.Parse(defaultTemplate) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// This should not happen.
 | ||||||
|  | 		// Default templates should be tested and so should never fail to parse.
 | ||||||
|  | 		logger.Panic("Could not parse defaultTemplate: ", err) | ||||||
|  | 	} | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // isFile checks if the file exists and checks whether it is a regular file.
 | ||||||
|  | // If either of these fail then it cannot be used as a template file.
 | ||||||
|  | func isFile(fileName string) bool { | ||||||
|  | 	info, err := os.Stat(fileName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Could not load file %s: %v, will use default template", fileName, err) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return info.Mode().IsRegular() | ||||||
|  | } | ||||||
|  | @ -0,0 +1,199 @@ | ||||||
|  | package app | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"html/template" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Templates", func() { | ||||||
|  | 	var customDir string | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		var err error | ||||||
|  | 		customDir, err = ioutil.TempDir("", "oauth2-proxy-templates-test") | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` | ||||||
|  | 		signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 		Expect(ioutil.WriteFile(signInFile, []byte(templateHTML), 0666)).To(Succeed()) | ||||||
|  | 		errorFile := filepath.Join(customDir, errorTemplateName) | ||||||
|  | 		Expect(ioutil.WriteFile(errorFile, []byte(templateHTML), 0666)).To(Succeed()) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	AfterEach(func() { | ||||||
|  | 		Expect(os.RemoveAll(customDir)).To(Succeed()) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("LoadTemplates", func() { | ||||||
|  | 		var data interface{} | ||||||
|  | 		var t *template.Template | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			data = struct { | ||||||
|  | 				// For default templates
 | ||||||
|  | 				ProxyPrefix string | ||||||
|  | 				Redirect    string | ||||||
|  | 				Footer      string | ||||||
|  | 
 | ||||||
|  | 				// For default sign_in template
 | ||||||
|  | 				SignInMessage string | ||||||
|  | 				ProviderName  string | ||||||
|  | 				CustomLogin   bool | ||||||
|  | 
 | ||||||
|  | 				// For default error template
 | ||||||
|  | 				StatusCode int | ||||||
|  | 				Title      string | ||||||
|  | 				Message    string | ||||||
|  | 
 | ||||||
|  | 				// For custom templates
 | ||||||
|  | 				TestString string | ||||||
|  | 			}{ | ||||||
|  | 				ProxyPrefix: "<proxy-prefix>", | ||||||
|  | 				Redirect:    "<redirect>", | ||||||
|  | 				Footer:      "<footer>", | ||||||
|  | 
 | ||||||
|  | 				SignInMessage: "<sign-in-message>", | ||||||
|  | 				ProviderName:  "<provider-name>", | ||||||
|  | 				CustomLogin:   false, | ||||||
|  | 
 | ||||||
|  | 				StatusCode: 404, | ||||||
|  | 				Title:      "<title>", | ||||||
|  | 				Message:    "<message>", | ||||||
|  | 
 | ||||||
|  | 				TestString: "Testing", | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("With no custom directory", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				t, err = LoadTemplates("") | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Use the default sign_in page", func() { | ||||||
|  | 				buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 				Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Use the default error page", func() { | ||||||
|  | 				buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 				Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 				Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("With a custom directory", func() { | ||||||
|  | 			Context("With both templates", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					var err error | ||||||
|  | 					t, err = LoadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With no error template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					Expect(os.Remove(filepath.Join(customDir, errorTemplateName))).To(Succeed()) | ||||||
|  | 
 | ||||||
|  | 					var err error | ||||||
|  | 					t, err = LoadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the default error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With no sign_in template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					Expect(os.Remove(filepath.Join(customDir, signInTemplateName))).To(Succeed()) | ||||||
|  | 
 | ||||||
|  | 					var err error | ||||||
|  | 					t, err = LoadTemplates(customDir) | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the default sign_in page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, signInTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(HavePrefix("\n<!DOCTYPE html>")) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Use the custom error page", func() { | ||||||
|  | 					buf := bytes.NewBuffer([]byte{}) | ||||||
|  | 					Expect(t.ExecuteTemplate(buf, errorTemplateName, data)).To(Succeed()) | ||||||
|  | 					Expect(buf.String()).To(Equal("Testing testing TESTING")) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With an invalid sign_in template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					signInFile := filepath.Join(customDir, signInTemplateName) | ||||||
|  | 					Expect(ioutil.WriteFile(signInFile, []byte("{{"), 0666)) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Should return an error when loading templates", func() { | ||||||
|  | 					t, err := LoadTemplates(customDir) | ||||||
|  | 					Expect(err).To(MatchError(HavePrefix("could not add Sign In template:"))) | ||||||
|  | 					Expect(t).To(BeNil()) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("With an invalid error template", func() { | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					errorFile := filepath.Join(customDir, errorTemplateName) | ||||||
|  | 					Expect(ioutil.WriteFile(errorFile, []byte("{{"), 0666)) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				It("Should return an error when loading templates", func() { | ||||||
|  | 					t, err := LoadTemplates(customDir) | ||||||
|  | 					Expect(err).To(MatchError(HavePrefix("could not add Error template:"))) | ||||||
|  | 					Expect(t).To(BeNil()) | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("isFile", func() { | ||||||
|  | 		It("with a valid file", func() { | ||||||
|  | 			Expect(isFile(filepath.Join(customDir, signInTemplateName))).To(BeTrue()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("with a directory", func() { | ||||||
|  | 			Expect(isFile(customDir)).To(BeFalse()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("with an invalid file", func() { | ||||||
|  | 			Expect(isFile(filepath.Join(customDir, "does_not_exist.html"))).To(BeFalse()) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
							
								
								
									
										187
									
								
								templates.go
								
								
								
								
							
							
						
						
									
										187
									
								
								templates.go
								
								
								
								
							|  | @ -1,187 +0,0 @@ | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"html/template" |  | ||||||
| 	"path" |  | ||||||
| 	"strings" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func loadTemplates(dir string) *template.Template { |  | ||||||
| 	if dir == "" { |  | ||||||
| 		return getTemplates() |  | ||||||
| 	} |  | ||||||
| 	logger.Printf("using custom template directory %q", dir) |  | ||||||
| 	funcMap := template.FuncMap{ |  | ||||||
| 		"ToUpper": strings.ToUpper, |  | ||||||
| 		"ToLower": strings.ToLower, |  | ||||||
| 	} |  | ||||||
| 	t, err := template.New("").Funcs(funcMap).ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html")) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("failed parsing template %s", err) |  | ||||||
| 	} |  | ||||||
| 	return t |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func getTemplates() *template.Template { |  | ||||||
| 	t, err := template.New("foo").Parse(`{{define "sign_in.html"}} |  | ||||||
| <!DOCTYPE html> |  | ||||||
| <html lang="en" charset="utf-8"> |  | ||||||
| <head> |  | ||||||
| 	<title>Sign In</title> |  | ||||||
| 	<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> |  | ||||||
| 	<style> |  | ||||||
| 	body { |  | ||||||
| 		font-family: "Helvetica Neue",Helvetica,Arial,sans-serif; |  | ||||||
| 		font-size: 14px; |  | ||||||
| 		line-height: 1.42857143; |  | ||||||
| 		color: #333; |  | ||||||
| 		background: #f0f0f0; |  | ||||||
| 	} |  | ||||||
| 	.signin { |  | ||||||
| 		display:block; |  | ||||||
| 		margin:20px auto; |  | ||||||
| 		max-width:400px; |  | ||||||
| 		background: #fff; |  | ||||||
| 		border:1px solid #ccc; |  | ||||||
| 		border-radius: 10px; |  | ||||||
| 		padding: 20px; |  | ||||||
| 	} |  | ||||||
| 	.center { |  | ||||||
| 		text-align:center; |  | ||||||
| 	} |  | ||||||
| 	.btn { |  | ||||||
| 		color: #fff; |  | ||||||
| 		background-color: #428bca; |  | ||||||
| 		border: 1px solid #357ebd; |  | ||||||
| 		-webkit-border-radius: 4; |  | ||||||
| 		-moz-border-radius: 4; |  | ||||||
| 		border-radius: 4px; |  | ||||||
| 		font-size: 14px; |  | ||||||
| 		padding: 6px 12px; |  | ||||||
| 	  	text-decoration: none; |  | ||||||
| 		cursor: pointer; |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	.btn:hover { |  | ||||||
| 		background-color: #3071a9; |  | ||||||
| 		border-color: #285e8e; |  | ||||||
| 		text-decoration: none; |  | ||||||
| 	} |  | ||||||
| 	label { |  | ||||||
| 		display: inline-block; |  | ||||||
| 		max-width: 100%; |  | ||||||
| 		margin-bottom: 5px; |  | ||||||
| 		font-weight: 700; |  | ||||||
| 	} |  | ||||||
| 	input { |  | ||||||
| 		display: block; |  | ||||||
| 		width: 100%; |  | ||||||
| 		height: 34px; |  | ||||||
| 		padding: 6px 12px; |  | ||||||
| 		font-size: 14px; |  | ||||||
| 		line-height: 1.42857143; |  | ||||||
| 		color: #555; |  | ||||||
| 		background-color: #fff; |  | ||||||
| 		background-image: none; |  | ||||||
| 		border: 1px solid #ccc; |  | ||||||
| 		border-radius: 4px; |  | ||||||
| 		-webkit-box-shadow: inset 0 1px 1px rgba(0,0,0,.075); |  | ||||||
| 		box-shadow: inset 0 1px 1px rgba(0,0,0,.075); |  | ||||||
| 		-webkit-transition: border-color ease-in-out .15s,-webkit-box-shadow ease-in-out .15s; |  | ||||||
| 		-o-transition: border-color ease-in-out .15s,box-shadow ease-in-out .15s; |  | ||||||
| 		transition: border-color ease-in-out .15s,box-shadow ease-in-out .15s; |  | ||||||
| 		margin:0; |  | ||||||
| 		box-sizing: border-box; |  | ||||||
| 	} |  | ||||||
| 	footer { |  | ||||||
| 		display:block; |  | ||||||
| 		font-size:10px; |  | ||||||
| 		color:#aaa; |  | ||||||
| 		text-align:center; |  | ||||||
| 		margin-bottom:10px; |  | ||||||
| 	} |  | ||||||
| 	footer a { |  | ||||||
| 		display:inline-block; |  | ||||||
| 		height:25px; |  | ||||||
| 		line-height:25px; |  | ||||||
| 		color:#aaa; |  | ||||||
| 		text-decoration:underline; |  | ||||||
| 	} |  | ||||||
| 	footer a:hover { |  | ||||||
| 		color:#aaa; |  | ||||||
| 	} |  | ||||||
| 	</style> |  | ||||||
| </head> |  | ||||||
| <body> |  | ||||||
| 	<div class="signin center"> |  | ||||||
| 	<form method="GET" action="{{.ProxyPrefix}}/start"> |  | ||||||
| 	<input type="hidden" name="rd" value="{{.Redirect}}"> |  | ||||||
| 	{{ if .SignInMessage }} |  | ||||||
| 	<p>{{.SignInMessage}}</p> |  | ||||||
| 	{{ end}} |  | ||||||
| 	<button type="submit" class="btn">Sign in with {{.ProviderName}}</button><br/> |  | ||||||
| 	</form> |  | ||||||
| 	</div> |  | ||||||
| 
 |  | ||||||
| 	{{ if .CustomLogin }} |  | ||||||
| 	<div class="signin"> |  | ||||||
| 	<form method="POST" action="{{.ProxyPrefix}}/sign_in"> |  | ||||||
| 		<input type="hidden" name="rd" value="{{.Redirect}}"> |  | ||||||
| 		<label for="username">Username:</label><input type="text" name="username" id="username" size="10"><br/> |  | ||||||
| 		<label for="password">Password:</label><input type="password" name="password" id="password" size="10"><br/> |  | ||||||
| 		<button type="submit" class="btn">Sign In</button> |  | ||||||
| 	</form> |  | ||||||
| 	</div> |  | ||||||
| 	{{ end }} |  | ||||||
| 	<script> |  | ||||||
| 		if (window.location.hash) { |  | ||||||
| 			(function() { |  | ||||||
| 				var inputs = document.getElementsByName('rd'); |  | ||||||
| 				for (var i = 0; i < inputs.length; i++) { |  | ||||||
| 					// Add hash, but make sure it is only added once
 |  | ||||||
| 					var idx = inputs[i].value.indexOf('#'); |  | ||||||
| 					if (idx >= 0) { |  | ||||||
| 						// Remove existing hash from URL
 |  | ||||||
| 						inputs[i].value = inputs[i].value.substr(0, idx); |  | ||||||
| 					} |  | ||||||
| 					inputs[i].value += window.location.hash; |  | ||||||
| 				} |  | ||||||
| 			})(); |  | ||||||
| 		} |  | ||||||
| 	</script> |  | ||||||
| 	<footer> |  | ||||||
| 	{{ if eq .Footer "-" }} |  | ||||||
| 	{{ else if eq .Footer ""}} |  | ||||||
| 	Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}} |  | ||||||
| 	{{ else }} |  | ||||||
| 	{{.Footer}} |  | ||||||
| 	{{ end }} |  | ||||||
| 	</footer> |  | ||||||
| </body> |  | ||||||
| </html> |  | ||||||
| {{end}}`) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("failed parsing template %s", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	t, err = t.Parse(`{{define "error.html"}} |  | ||||||
| <!DOCTYPE html> |  | ||||||
| <html lang="en" charset="utf-8"> |  | ||||||
| <head> |  | ||||||
| 	<title>{{.Title}}</title> |  | ||||||
| 	<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no"> |  | ||||||
| </head> |  | ||||||
| <body> |  | ||||||
| 	<h2>{{.Title}}</h2> |  | ||||||
| 	<p>{{.Message}}</p> |  | ||||||
| 	<hr> |  | ||||||
| 	<p><a href="{{.ProxyPrefix}}/sign_in">Sign In</a></p> |  | ||||||
| </body> |  | ||||||
| </html>{{end}}`) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("failed parsing template %s", err) |  | ||||||
| 	} |  | ||||||
| 	return t |  | ||||||
| } |  | ||||||
|  | @ -1,62 +0,0 @@ | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"io/ioutil" |  | ||||||
| 	"log" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"testing" |  | ||||||
| 
 |  | ||||||
| 	"github.com/stretchr/testify/assert" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func TestLoadTemplates(t *testing.T) { |  | ||||||
| 	data := struct { |  | ||||||
| 		TestString string |  | ||||||
| 	}{ |  | ||||||
| 		TestString: "Testing", |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	templates := loadTemplates("") |  | ||||||
| 	assert.NotEqual(t, templates, nil) |  | ||||||
| 
 |  | ||||||
| 	var defaultSignin bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&defaultSignin, "sign_in.html", data) |  | ||||||
| 	assert.Equal(t, "\n<!DOCTYPE html>", defaultSignin.String()[0:16]) |  | ||||||
| 
 |  | ||||||
| 	var defaultError bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&defaultError, "error.html", data) |  | ||||||
| 	assert.Equal(t, "\n<!DOCTYPE html>", defaultError.String()[0:16]) |  | ||||||
| 
 |  | ||||||
| 	dir, err := ioutil.TempDir("", "templatetest") |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer os.RemoveAll(dir) |  | ||||||
| 
 |  | ||||||
| 	templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}` |  | ||||||
| 	signInFile := filepath.Join(dir, "sign_in.html") |  | ||||||
| 	if err := ioutil.WriteFile(signInFile, []byte(templateHTML), 0666); err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	errorFile := filepath.Join(dir, "error.html") |  | ||||||
| 	if err := ioutil.WriteFile(errorFile, []byte(templateHTML), 0666); err != nil { |  | ||||||
| 		log.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	templates = loadTemplates(dir) |  | ||||||
| 	assert.NotEqual(t, templates, nil) |  | ||||||
| 
 |  | ||||||
| 	var sitpl bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&sitpl, "sign_in.html", data) |  | ||||||
| 	assert.Equal(t, "Testing testing TESTING", sitpl.String()) |  | ||||||
| 
 |  | ||||||
| 	var errtpl bytes.Buffer |  | ||||||
| 	templates.ExecuteTemplate(&errtpl, "error.html", data) |  | ||||||
| 	assert.Equal(t, "Testing testing TESTING", errtpl.String()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestTemplatesCompile(t *testing.T) { |  | ||||||
| 	templates := getTemplates() |  | ||||||
| 	assert.NotEqual(t, templates, nil) |  | ||||||
| } |  | ||||||
		Loading…
	
		Reference in New Issue