Improve handler vs helper organization in oauthproxy.go
Additionally, convert a lot of helper methods to be private
This commit is contained in:
		
							parent
							
								
									73fc7706bc
								
							
						
					
					
						commit
						fa6a785eaf
					
				
							
								
								
									
										652
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										652
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -31,9 +31,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	httpScheme  = "http" | 	schemeHTTPS     = "https" | ||||||
| 	httpsScheme = "https" |  | ||||||
| 
 |  | ||||||
| 	applicationJSON = "application/json" | 	applicationJSON = "application/json" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { | ||||||
| 	return routes, nil | 	return routes, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
 |  | ||||||
| // redirect clients to once authenticated
 |  | ||||||
| func (p *OAuthProxy) GetOAuthRedirectURI(host string) string { |  | ||||||
| 	// default to the request Host if not set
 |  | ||||||
| 	if p.redirectURL.Host != "" { |  | ||||||
| 		return p.redirectURL.String() |  | ||||||
| 	} |  | ||||||
| 	u := *p.redirectURL |  | ||||||
| 	if u.Scheme == "" { |  | ||||||
| 		if p.CookieSecure { |  | ||||||
| 			u.Scheme = httpsScheme |  | ||||||
| 		} else { |  | ||||||
| 			u.Scheme = httpScheme |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	u.Host = host |  | ||||||
| 	return u.String() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { |  | ||||||
| 	if code == "" { |  | ||||||
| 		return nil, providers.ErrMissingCode |  | ||||||
| 	} |  | ||||||
| 	redirectURI := p.GetOAuthRedirectURI(host) |  | ||||||
| 	s, err := p.provider.Redeem(ctx, redirectURI, code) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	return s, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { |  | ||||||
| 	var err error |  | ||||||
| 	if s.Email == "" { |  | ||||||
| 		s.Email, err = p.provider.GetEmailAddress(ctx, s) |  | ||||||
| 		if err != nil && !errors.Is(err, providers.ErrNotImplemented) { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return p.provider.EnrichSession(ctx, s) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // MakeCSRFCookie creates a cookie for CSRF
 | // MakeCSRFCookie creates a cookie for CSRF
 | ||||||
| func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||||
| 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | ||||||
|  | @ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s | ||||||
| 	return p.sessionStore.Save(rw, req, s) | 	return p.sessionStore.Save(rw, req, s) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // IsValidRedirect checks whether the redirect URL is whitelisted
 | ||||||
|  | func (p *OAuthProxy) IsValidRedirect(redirect string) bool { | ||||||
|  | 	switch { | ||||||
|  | 	case redirect == "": | ||||||
|  | 		// The user didn't specify a redirect, should fallback to `/`
 | ||||||
|  | 		return false | ||||||
|  | 	case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): | ||||||
|  | 		return true | ||||||
|  | 	case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): | ||||||
|  | 		redirectURL, err := url.Parse(redirect) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 		redirectHostname := redirectURL.Hostname() | ||||||
|  | 
 | ||||||
|  | 		for _, domain := range p.whitelistDomains { | ||||||
|  | 			domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) | ||||||
|  | 			if domainHostname == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { | ||||||
|  | 				// the domain names match, now validate the ports
 | ||||||
|  | 				// if the whitelisted domain's port is '*', allow all ports
 | ||||||
|  | 				// if the whitelisted domain contains a specific port, only allow that port
 | ||||||
|  | 				// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
 | ||||||
|  | 				redirectPort := redirectURL.Port() | ||||||
|  | 				if (domainPort == "*") || | ||||||
|  | 					(domainPort == redirectPort) || | ||||||
|  | 					(domainPort == "" && redirectPort == "") { | ||||||
|  | 					return true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) | ||||||
|  | 		return false | ||||||
|  | 	default: | ||||||
|  | 		logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 	p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 	if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { | ||||||
|  | 		prepareNoCache(rw) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch path := req.URL.Path; { | ||||||
|  | 	case path == p.RobotsPath: | ||||||
|  | 		p.RobotsTxt(rw) | ||||||
|  | 	case p.IsAllowedRequest(req): | ||||||
|  | 		p.SkipAuthProxy(rw, req) | ||||||
|  | 	case path == p.SignInPath: | ||||||
|  | 		p.SignIn(rw, req) | ||||||
|  | 	case path == p.SignOutPath: | ||||||
|  | 		p.SignOut(rw, req) | ||||||
|  | 	case path == p.OAuthStartPath: | ||||||
|  | 		p.OAuthStart(rw, req) | ||||||
|  | 	case path == p.OAuthCallbackPath: | ||||||
|  | 		p.OAuthCallback(rw, req) | ||||||
|  | 	case path == p.AuthOnlyPath: | ||||||
|  | 		p.AuthOnly(rw, req) | ||||||
|  | 	case path == p.UserInfoPath: | ||||||
|  | 		p.UserInfo(rw, req) | ||||||
|  | 	default: | ||||||
|  | 		p.Proxy(rw, req) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // RobotsTxt disallows scraping pages from the OAuthProxy
 | // RobotsTxt disallows scraping pages from the OAuthProxy
 | ||||||
| func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { | func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { | ||||||
| 	_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | 	_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | ||||||
|  | @ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // IsAllowedRequest is used to check if auth should be skipped for this request
 | ||||||
|  | func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { | ||||||
|  | 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" | ||||||
|  | 	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsAllowedRoute is used to check if the request method & path is allowed without auth
 | ||||||
|  | func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { | ||||||
|  | 	for _, route := range p.allowedRoutes { | ||||||
|  | 		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // isTrustedIP is used to check if a request comes from a trusted client IP address.
 | ||||||
|  | func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { | ||||||
|  | 	if p.trustedIPs == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) | ||||||
|  | 		// Possibly spoofed X-Real-IP header
 | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if remoteAddr == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return p.trustedIPs.Has(remoteAddr) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // SignInPage writes the sing in template to the response
 | // SignInPage writes the sing in template to the response
 | ||||||
| func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { | func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { | ||||||
| 	prepareNoCache(rw) | 	prepareNoCache(rw) | ||||||
|  | @ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 	} | 	} | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
| 
 | 
 | ||||||
| 	redirectURL, err := p.GetAppRedirect(req) | 	redirectURL, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
|  | @ -566,276 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { | ||||||
| 	return "", false | 	return "", false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetAppRedirect determines the full URL or URI path to redirect clients to
 |  | ||||||
| // once authenticated with the OAuthProxy
 |  | ||||||
| // Strategy priority (first legal result is used):
 |  | ||||||
| // - `rd` querysting parameter
 |  | ||||||
| // - `X-Auth-Request-Redirect` header
 |  | ||||||
| // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 |  | ||||||
| // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 |  | ||||||
| // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 |  | ||||||
| // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 |  | ||||||
| // - `/`
 |  | ||||||
| func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) { |  | ||||||
| 	err := req.ParseForm() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// These redirect getter functions are strategies ordered by priority
 |  | ||||||
| 	// for figuring out the redirect URL.
 |  | ||||||
| 	type redirectGetter func(req *http.Request) string |  | ||||||
| 	for _, rdGetter := range []redirectGetter{ |  | ||||||
| 		p.getRdQuerystringRedirect, |  | ||||||
| 		p.getXAuthRequestRedirect, |  | ||||||
| 		p.getXForwardedHeadersRedirect, |  | ||||||
| 		p.getURIRedirect, |  | ||||||
| 	} { |  | ||||||
| 		if redirect := rdGetter(req); redirect != "" { |  | ||||||
| 			return redirect, nil |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return "/", nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func isForwardedRequest(req *http.Request) bool { |  | ||||||
| 	return requestutil.IsProxied(req) && |  | ||||||
| 		req.Host != requestutil.GetRequestHost(req) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) hasProxyPrefix(path string) bool { |  | ||||||
| 	return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // getRdQuerystringRedirect handles this GetAppRedirect strategy:
 |  | ||||||
| // - `rd` querysting parameter
 |  | ||||||
| func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { |  | ||||||
| 	redirect := req.Form.Get("rd") |  | ||||||
| 	if p.IsValidRedirect(redirect) { |  | ||||||
| 		return redirect |  | ||||||
| 	} |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // getXAuthRequestRedirect handles this GetAppRedirect strategy:
 |  | ||||||
| // - `X-Auth-Request-Redirect` Header
 |  | ||||||
| func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { |  | ||||||
| 	redirect := req.Header.Get("X-Auth-Request-Redirect") |  | ||||||
| 	if p.IsValidRedirect(redirect) { |  | ||||||
| 		return redirect |  | ||||||
| 	} |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // getXForwardedHeadersRedirect handles these GetAppRedirect strategies:
 |  | ||||||
| // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 |  | ||||||
| // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 |  | ||||||
| func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { |  | ||||||
| 	if !isForwardedRequest(req) { |  | ||||||
| 		return "" |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	uri := requestutil.GetRequestURI(req) |  | ||||||
| 	if p.hasProxyPrefix(uri) { |  | ||||||
| 		uri = "/" |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	redirect := fmt.Sprintf( |  | ||||||
| 		"%s://%s%s", |  | ||||||
| 		requestutil.GetRequestProto(req), |  | ||||||
| 		requestutil.GetRequestHost(req), |  | ||||||
| 		uri, |  | ||||||
| 	) |  | ||||||
| 
 |  | ||||||
| 	if p.IsValidRedirect(redirect) { |  | ||||||
| 		return redirect |  | ||||||
| 	} |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // getURIRedirect handles these GetAppRedirect strategies:
 |  | ||||||
| // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 |  | ||||||
| // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 |  | ||||||
| // - `/`
 |  | ||||||
| func (p *OAuthProxy) getURIRedirect(req *http.Request) string { |  | ||||||
| 	redirect := requestutil.GetRequestURI(req) |  | ||||||
| 	if !p.IsValidRedirect(redirect) { |  | ||||||
| 		redirect = req.URL.RequestURI() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if p.hasProxyPrefix(redirect) { |  | ||||||
| 		return "/" |  | ||||||
| 	} |  | ||||||
| 	return redirect |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // splitHostPort separates host and port. If the port is not valid, it returns
 |  | ||||||
| // the entire input as host, and it doesn't check the validity of the host.
 |  | ||||||
| // Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
 |  | ||||||
| // *** taken from net/url, modified validOptionalPort() to accept ":*"
 |  | ||||||
| func splitHostPort(hostport string) (host, port string) { |  | ||||||
| 	host = hostport |  | ||||||
| 
 |  | ||||||
| 	colon := strings.LastIndexByte(host, ':') |  | ||||||
| 	if colon != -1 && validOptionalPort(host[colon:]) { |  | ||||||
| 		host, port = host[:colon], host[colon+1:] |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { |  | ||||||
| 		host = host[1 : len(host)-1] |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // validOptionalPort reports whether port is either an empty string
 |  | ||||||
| // or matches /^:\d*$/
 |  | ||||||
| // *** taken from net/url, modified to accept ":*"
 |  | ||||||
| func validOptionalPort(port string) bool { |  | ||||||
| 	if port == "" || port == ":*" { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	if port[0] != ':' { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 	for _, b := range port[1:] { |  | ||||||
| 		if b < '0' || b > '9' { |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // IsValidRedirect checks whether the redirect URL is whitelisted
 |  | ||||||
| func (p *OAuthProxy) IsValidRedirect(redirect string) bool { |  | ||||||
| 	switch { |  | ||||||
| 	case redirect == "": |  | ||||||
| 		// The user didn't specify a redirect, should fallback to `/`
 |  | ||||||
| 		return false |  | ||||||
| 	case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): |  | ||||||
| 		return true |  | ||||||
| 	case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): |  | ||||||
| 		redirectURL, err := url.Parse(redirect) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 		redirectHostname := redirectURL.Hostname() |  | ||||||
| 
 |  | ||||||
| 		for _, domain := range p.whitelistDomains { |  | ||||||
| 			domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) |  | ||||||
| 			if domainHostname == "" { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { |  | ||||||
| 				// the domain names match, now validate the ports
 |  | ||||||
| 				// if the whitelisted domain's port is '*', allow all ports
 |  | ||||||
| 				// if the whitelisted domain contains a specific port, only allow that port
 |  | ||||||
| 				// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
 |  | ||||||
| 				redirectPort := redirectURL.Port() |  | ||||||
| 				if (domainPort == "*") || |  | ||||||
| 					(domainPort == redirectPort) || |  | ||||||
| 					(domainPort == "" && redirectPort == "") { |  | ||||||
| 					return true |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) |  | ||||||
| 		return false |  | ||||||
| 	default: |  | ||||||
| 		logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // IsAllowedRequest is used to check if auth should be skipped for this request
 |  | ||||||
| func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { |  | ||||||
| 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" |  | ||||||
| 	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // IsAllowedRoute is used to check if the request method & path is allowed without auth
 |  | ||||||
| func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { |  | ||||||
| 	for _, route := range p.allowedRoutes { |  | ||||||
| 		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 |  | ||||||
| var noCacheHeaders = map[string]string{ |  | ||||||
| 	"Expires":         time.Unix(0, 0).Format(time.RFC1123), |  | ||||||
| 	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0", |  | ||||||
| 	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // prepareNoCache prepares headers for preventing browser caching.
 |  | ||||||
| func prepareNoCache(w http.ResponseWriter) { |  | ||||||
| 	// Set NoCache headers
 |  | ||||||
| 	for k, v := range noCacheHeaders { |  | ||||||
| 		w.Header().Set(k, v) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // IsTrustedIP is used to check if a request comes from a trusted client IP address.
 |  | ||||||
| func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool { |  | ||||||
| 	if p.trustedIPs == nil { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) |  | ||||||
| 		// Possibly spoofed X-Real-IP header
 |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if remoteAddr == nil { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return p.trustedIPs.Has(remoteAddr) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { |  | ||||||
| 	p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { |  | ||||||
| 	if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { |  | ||||||
| 		prepareNoCache(rw) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	switch path := req.URL.Path; { |  | ||||||
| 	case path == p.RobotsPath: |  | ||||||
| 		p.RobotsTxt(rw) |  | ||||||
| 	case p.IsAllowedRequest(req): |  | ||||||
| 		p.SkipAuthProxy(rw, req) |  | ||||||
| 	case path == p.SignInPath: |  | ||||||
| 		p.SignIn(rw, req) |  | ||||||
| 	case path == p.SignOutPath: |  | ||||||
| 		p.SignOut(rw, req) |  | ||||||
| 	case path == p.OAuthStartPath: |  | ||||||
| 		p.OAuthStart(rw, req) |  | ||||||
| 	case path == p.OAuthCallbackPath: |  | ||||||
| 		p.OAuthCallback(rw, req) |  | ||||||
| 	case path == p.AuthOnlyPath: |  | ||||||
| 		p.AuthOnly(rw, req) |  | ||||||
| 	case path == p.UserInfoPath: |  | ||||||
| 		p.UserInfo(rw, req) |  | ||||||
| 	default: |  | ||||||
| 		p.Proxy(rw, req) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // SignIn serves a page prompting users to sign in
 | // SignIn serves a page prompting users to sign in
 | ||||||
| func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	redirect, err := p.GetAppRedirect(req) | 	redirect, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
|  | @ -893,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { | ||||||
| 
 | 
 | ||||||
| // SignOut sends a response to clear the authentication cookie
 | // SignOut sends a response to clear the authentication cookie
 | ||||||
| func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	redirect, err := p.GetAppRedirect(req) | 	redirect, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
|  | @ -918,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	p.SetCSRFCookie(rw, req, nonce) | 	p.SetCSRFCookie(rw, req, nonce) | ||||||
| 	redirect, err := p.GetAppRedirect(req) | 	redirect, err := p.getAppRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | 		logger.Errorf("Error obtaining redirect: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req)) | 	redirectURI := p.getOAuthRedirectURI(req) | ||||||
| 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) | 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -947,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code")) | 	session, err := p.redeemCode(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) | 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | ||||||
|  | @ -1006,6 +805,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 	code := req.Form.Get("code") | ||||||
|  | 	if code == "" { | ||||||
|  | 		return nil, providers.ErrMissingCode | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	redirectURI := p.getOAuthRedirectURI(req) | ||||||
|  | 	s, err := p.provider.Redeem(req.Context(), redirectURI, code) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return s, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { | ||||||
|  | 	var err error | ||||||
|  | 	if s.Email == "" { | ||||||
|  | 		s.Email, err = p.provider.GetEmailAddress(ctx, s) | ||||||
|  | 		if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return p.provider.EnrichSession(ctx, s) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // AuthOnly checks whether the user is currently logged in (both authentication
 | // AuthOnly checks whether the user is currently logged in (both authentication
 | ||||||
| // and optional authorization).
 | // and optional authorization).
 | ||||||
| func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | @ -1023,7 +848,7 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// we are authenticated
 | 	// we are authenticated
 | ||||||
| 	p.addHeadersForProxying(rw, req, session) | 	p.addHeadersForProxying(rw, session) | ||||||
| 	p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | 	p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		rw.WriteHeader(http.StatusAccepted) | 		rw.WriteHeader(http.StatusAccepted) | ||||||
| 	})).ServeHTTP(rw, req) | 	})).ServeHTTP(rw, req) | ||||||
|  | @ -1041,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	switch err { | 	switch err { | ||||||
| 	case nil: | 	case nil: | ||||||
| 		// we are authenticated
 | 		// we are authenticated
 | ||||||
| 		p.addHeadersForProxying(rw, req, session) | 		p.addHeadersForProxying(rw, session) | ||||||
| 		p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) | 		p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) | ||||||
| 	case ErrNeedsLogin: | 	case ErrNeedsLogin: | ||||||
| 		// we need to send the user to a login screen
 | 		// we need to send the user to a login screen
 | ||||||
| 		if isAjax(req) { | 		if isAjax(req) { | ||||||
| 			// no point redirecting an AJAX request
 | 			// no point redirecting an AJAX request
 | ||||||
| 			p.ErrorJSON(rw, http.StatusUnauthorized) | 			p.errorJSON(rw, http.StatusUnauthorized) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -1066,7 +891,184 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, | 		p.ErrorPage(rw, http.StatusInternalServerError, | ||||||
| 			"Internal Error", "Internal Error") | 			"Internal Error", "Internal Error") | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
|  | // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 | ||||||
|  | var noCacheHeaders = map[string]string{ | ||||||
|  | 	"Expires":         time.Unix(0, 0).Format(time.RFC1123), | ||||||
|  | 	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0", | ||||||
|  | 	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // prepareNoCache prepares headers for preventing browser caching.
 | ||||||
|  | func prepareNoCache(w http.ResponseWriter) { | ||||||
|  | 	// Set NoCache headers
 | ||||||
|  | 	for k, v := range noCacheHeaders { | ||||||
|  | 		w.Header().Set(k, v) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
 | ||||||
|  | // redirect clients to once authenticated.
 | ||||||
|  | // This is usually the OAuthProxy callback URL.
 | ||||||
|  | func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string { | ||||||
|  | 	// if `p.redirectURL` already has a host, return it
 | ||||||
|  | 	if p.redirectURL.Host != "" { | ||||||
|  | 		return p.redirectURL.String() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Otherwise figure out the scheme + host from the request
 | ||||||
|  | 	rd := *p.redirectURL | ||||||
|  | 	rd.Host = requestutil.GetRequestHost(req) | ||||||
|  | 	rd.Scheme = requestutil.GetRequestProto(req) | ||||||
|  | 
 | ||||||
|  | 	// If CookieSecure is true, return `https` no matter what
 | ||||||
|  | 	// Not all reverse proxies set X-Forwarded-Proto
 | ||||||
|  | 	if p.CookieSecure { | ||||||
|  | 		rd.Scheme = schemeHTTPS | ||||||
|  | 	} | ||||||
|  | 	return rd.String() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getAppRedirect determines the full URL or URI path to redirect clients to
 | ||||||
|  | // once authenticated with the OAuthProxy
 | ||||||
|  | // Strategy priority (first legal result is used):
 | ||||||
|  | // - `rd` querysting parameter
 | ||||||
|  | // - `X-Auth-Request-Redirect` header
 | ||||||
|  | // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 | ||||||
|  | // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 | ||||||
|  | // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 | ||||||
|  | // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 | ||||||
|  | // - `/`
 | ||||||
|  | func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) { | ||||||
|  | 	err := req.ParseForm() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// These redirect getter functions are strategies ordered by priority
 | ||||||
|  | 	// for figuring out the redirect URL.
 | ||||||
|  | 	type redirectGetter func(req *http.Request) string | ||||||
|  | 	for _, rdGetter := range []redirectGetter{ | ||||||
|  | 		p.getRdQuerystringRedirect, | ||||||
|  | 		p.getXAuthRequestRedirect, | ||||||
|  | 		p.getXForwardedHeadersRedirect, | ||||||
|  | 		p.getURIRedirect, | ||||||
|  | 	} { | ||||||
|  | 		if redirect := rdGetter(req); redirect != "" { | ||||||
|  | 			return redirect, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "/", nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func isForwardedRequest(req *http.Request) bool { | ||||||
|  | 	return requestutil.IsProxied(req) && | ||||||
|  | 		req.Host != requestutil.GetRequestHost(req) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OAuthProxy) hasProxyPrefix(path string) bool { | ||||||
|  | 	return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getRdQuerystringRedirect handles this getAppRedirect strategy:
 | ||||||
|  | // - `rd` querysting parameter
 | ||||||
|  | func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { | ||||||
|  | 	redirect := req.Form.Get("rd") | ||||||
|  | 	if p.IsValidRedirect(redirect) { | ||||||
|  | 		return redirect | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getXAuthRequestRedirect handles this getAppRedirect strategy:
 | ||||||
|  | // - `X-Auth-Request-Redirect` Header
 | ||||||
|  | func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { | ||||||
|  | 	redirect := req.Header.Get("X-Auth-Request-Redirect") | ||||||
|  | 	if p.IsValidRedirect(redirect) { | ||||||
|  | 		return redirect | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getXForwardedHeadersRedirect handles these getAppRedirect strategies:
 | ||||||
|  | // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 | ||||||
|  | // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 | ||||||
|  | func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { | ||||||
|  | 	if !isForwardedRequest(req) { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	uri := requestutil.GetRequestURI(req) | ||||||
|  | 	if p.hasProxyPrefix(uri) { | ||||||
|  | 		uri = "/" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	redirect := fmt.Sprintf( | ||||||
|  | 		"%s://%s%s", | ||||||
|  | 		requestutil.GetRequestProto(req), | ||||||
|  | 		requestutil.GetRequestHost(req), | ||||||
|  | 		uri, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if p.IsValidRedirect(redirect) { | ||||||
|  | 		return redirect | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getURIRedirect handles these getAppRedirect strategies:
 | ||||||
|  | // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 | ||||||
|  | // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 | ||||||
|  | // - `/`
 | ||||||
|  | func (p *OAuthProxy) getURIRedirect(req *http.Request) string { | ||||||
|  | 	redirect := requestutil.GetRequestURI(req) | ||||||
|  | 	if !p.IsValidRedirect(redirect) { | ||||||
|  | 		redirect = req.URL.RequestURI() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if p.hasProxyPrefix(redirect) { | ||||||
|  | 		return "/" | ||||||
|  | 	} | ||||||
|  | 	return redirect | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // splitHostPort separates host and port. If the port is not valid, it returns
 | ||||||
|  | // the entire input as host, and it doesn't check the validity of the host.
 | ||||||
|  | // Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
 | ||||||
|  | // *** taken from net/url, modified validOptionalPort() to accept ":*"
 | ||||||
|  | func splitHostPort(hostport string) (host, port string) { | ||||||
|  | 	host = hostport | ||||||
|  | 
 | ||||||
|  | 	colon := strings.LastIndexByte(host, ':') | ||||||
|  | 	if colon != -1 && validOptionalPort(host[colon:]) { | ||||||
|  | 		host, port = host[:colon], host[colon+1:] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { | ||||||
|  | 		host = host[1 : len(host)-1] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // validOptionalPort reports whether port is either an empty string
 | ||||||
|  | // or matches /^:\d*$/
 | ||||||
|  | // *** taken from net/url, modified to accept ":*"
 | ||||||
|  | func validOptionalPort(port string) bool { | ||||||
|  | 	if port == "" || port == ":*" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if port[0] != ':' { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	for _, b := range port[1:] { | ||||||
|  | 		if b < '0' || b > '9' { | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
 | // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
 | ||||||
|  | @ -1153,7 +1155,7 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // addHeadersForProxying adds the appropriate headers the request / response for proxying
 | // addHeadersForProxying adds the appropriate headers the request / response for proxying
 | ||||||
| func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) { | func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) { | ||||||
| 	if session.Email == "" { | 	if session.Email == "" { | ||||||
| 		rw.Header().Set("GAP-Auth", session.User) | 		rw.Header().Set("GAP-Auth", session.User) | ||||||
| 	} else { | 	} else { | ||||||
|  | @ -1181,8 +1183,8 @@ func isAjax(req *http.Request) bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ErrorJSON returns the error code with an application/json mime type
 | // errorJSON returns the error code with an application/json mime type
 | ||||||
| func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) { | ||||||
| 	rw.Header().Set("Content-Type", applicationJSON) | 	rw.Header().Set("Content-Type", applicationJSON) | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -415,8 +415,9 @@ func Test_redeemCode(t *testing.T) { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	_, err = proxy.redeemCode(context.Background(), "www.example.com", "") | 	req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||||
| 	assert.Error(t, err) | 	_, err = proxy.redeemCode(req) | ||||||
|  | 	assert.Equal(t, providers.ErrMissingCode, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Test_enrichSession(t *testing.T) { | func Test_enrichSession(t *testing.T) { | ||||||
|  | @ -1749,7 +1750,7 @@ func TestRequestSignature(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestGetRedirect(t *testing.T) { | func Test_getAppRedirect(t *testing.T) { | ||||||
| 	opts := baseTestOptions() | 	opts := baseTestOptions() | ||||||
| 	opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") | 	opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") | ||||||
| 	err := validation.Validate(opts) | 	err := validation.Validate(opts) | ||||||
|  | @ -1900,7 +1901,7 @@ func TestGetRedirect(t *testing.T) { | ||||||
| 			req = middleware.AddRequestScope(req, &middleware.RequestScope{ | 			req = middleware.AddRequestScope(req, &middleware.RequestScope{ | ||||||
| 				ReverseProxy: tt.reverseProxy, | 				ReverseProxy: tt.reverseProxy, | ||||||
| 			}) | 			}) | ||||||
| 			redirect, err := proxy.GetAppRedirect(req) | 			redirect, err := proxy.getAppRedirect(req) | ||||||
| 
 | 
 | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
| 			assert.Equal(t, tt.expectedRedirect, redirect) | 			assert.Equal(t, tt.expectedRedirect, redirect) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue