Improve handler vs helper organization in oauthproxy.go
Additionally, convert a lot of helper methods to be private
This commit is contained in:
		
							parent
							
								
									73fc7706bc
								
							
						
					
					
						commit
						fa6a785eaf
					
				
							
								
								
									
										652
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										652
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -31,9 +31,7 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	httpScheme  = "http" | ||||
| 	httpsScheme = "https" | ||||
| 
 | ||||
| 	schemeHTTPS     = "https" | ||||
| 	applicationJSON = "application/json" | ||||
| ) | ||||
| 
 | ||||
|  | @ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { | |||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| // GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
 | ||||
| // redirect clients to once authenticated
 | ||||
| func (p *OAuthProxy) GetOAuthRedirectURI(host string) string { | ||||
| 	// default to the request Host if not set
 | ||||
| 	if p.redirectURL.Host != "" { | ||||
| 		return p.redirectURL.String() | ||||
| 	} | ||||
| 	u := *p.redirectURL | ||||
| 	if u.Scheme == "" { | ||||
| 		if p.CookieSecure { | ||||
| 			u.Scheme = httpsScheme | ||||
| 		} else { | ||||
| 			u.Scheme = httpScheme | ||||
| 		} | ||||
| 	} | ||||
| 	u.Host = host | ||||
| 	return u.String() | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, providers.ErrMissingCode | ||||
| 	} | ||||
| 	redirectURI := p.GetOAuthRedirectURI(host) | ||||
| 	s, err := p.provider.Redeem(ctx, redirectURI, code) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return s, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { | ||||
| 	var err error | ||||
| 	if s.Email == "" { | ||||
| 		s.Email, err = p.provider.GetEmailAddress(ctx, s) | ||||
| 		if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return p.provider.EnrichSession(ctx, s) | ||||
| } | ||||
| 
 | ||||
| // MakeCSRFCookie creates a cookie for CSRF
 | ||||
| func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||
| 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | ||||
|  | @ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s | |||
| 	return p.sessionStore.Save(rw, req, s) | ||||
| } | ||||
| 
 | ||||
| // IsValidRedirect checks whether the redirect URL is whitelisted
 | ||||
| func (p *OAuthProxy) IsValidRedirect(redirect string) bool { | ||||
| 	switch { | ||||
| 	case redirect == "": | ||||
| 		// The user didn't specify a redirect, should fallback to `/`
 | ||||
| 		return false | ||||
| 	case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): | ||||
| 		return true | ||||
| 	case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): | ||||
| 		redirectURL, err := url.Parse(redirect) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) | ||||
| 			return false | ||||
| 		} | ||||
| 		redirectHostname := redirectURL.Hostname() | ||||
| 
 | ||||
| 		for _, domain := range p.whitelistDomains { | ||||
| 			domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) | ||||
| 			if domainHostname == "" { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { | ||||
| 				// the domain names match, now validate the ports
 | ||||
| 				// if the whitelisted domain's port is '*', allow all ports
 | ||||
| 				// if the whitelisted domain contains a specific port, only allow that port
 | ||||
| 				// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
 | ||||
| 				redirectPort := redirectURL.Port() | ||||
| 				if (domainPort == "*") || | ||||
| 					(domainPort == redirectPort) || | ||||
| 					(domainPort == "" && redirectPort == "") { | ||||
| 					return true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) | ||||
| 		return false | ||||
| 	default: | ||||
| 		logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { | ||||
| 		prepareNoCache(rw) | ||||
| 	} | ||||
| 
 | ||||
| 	switch path := req.URL.Path; { | ||||
| 	case path == p.RobotsPath: | ||||
| 		p.RobotsTxt(rw) | ||||
| 	case p.IsAllowedRequest(req): | ||||
| 		p.SkipAuthProxy(rw, req) | ||||
| 	case path == p.SignInPath: | ||||
| 		p.SignIn(rw, req) | ||||
| 	case path == p.SignOutPath: | ||||
| 		p.SignOut(rw, req) | ||||
| 	case path == p.OAuthStartPath: | ||||
| 		p.OAuthStart(rw, req) | ||||
| 	case path == p.OAuthCallbackPath: | ||||
| 		p.OAuthCallback(rw, req) | ||||
| 	case path == p.AuthOnlyPath: | ||||
| 		p.AuthOnly(rw, req) | ||||
| 	case path == p.UserInfoPath: | ||||
| 		p.UserInfo(rw, req) | ||||
| 	default: | ||||
| 		p.Proxy(rw, req) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // RobotsTxt disallows scraping pages from the OAuthProxy
 | ||||
| func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { | ||||
| 	_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | ||||
|  | @ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // IsAllowedRequest is used to check if auth should be skipped for this request
 | ||||
| func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { | ||||
| 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" | ||||
| 	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) | ||||
| } | ||||
| 
 | ||||
| // IsAllowedRoute is used to check if the request method & path is allowed without auth
 | ||||
| func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { | ||||
| 	for _, route := range p.allowedRoutes { | ||||
| 		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // isTrustedIP is used to check if a request comes from a trusted client IP address.
 | ||||
| func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { | ||||
| 	if p.trustedIPs == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) | ||||
| 		// Possibly spoofed X-Real-IP header
 | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if remoteAddr == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return p.trustedIPs.Has(remoteAddr) | ||||
| } | ||||
| 
 | ||||
| // SignInPage writes the sing in template to the response
 | ||||
| func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { | ||||
| 	prepareNoCache(rw) | ||||
|  | @ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | |||
| 	} | ||||
| 	rw.WriteHeader(code) | ||||
| 
 | ||||
| 	redirectURL, err := p.GetAppRedirect(req) | ||||
| 	redirectURL, err := p.getAppRedirect(req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | ||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||
|  | @ -566,276 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { | |||
| 	return "", false | ||||
| } | ||||
| 
 | ||||
| // GetAppRedirect determines the full URL or URI path to redirect clients to
 | ||||
| // once authenticated with the OAuthProxy
 | ||||
| // Strategy priority (first legal result is used):
 | ||||
| // - `rd` querysting parameter
 | ||||
| // - `X-Auth-Request-Redirect` header
 | ||||
| // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 | ||||
| // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 | ||||
| // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 | ||||
| // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 | ||||
| // - `/`
 | ||||
| func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) { | ||||
| 	err := req.ParseForm() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	// These redirect getter functions are strategies ordered by priority
 | ||||
| 	// for figuring out the redirect URL.
 | ||||
| 	type redirectGetter func(req *http.Request) string | ||||
| 	for _, rdGetter := range []redirectGetter{ | ||||
| 		p.getRdQuerystringRedirect, | ||||
| 		p.getXAuthRequestRedirect, | ||||
| 		p.getXForwardedHeadersRedirect, | ||||
| 		p.getURIRedirect, | ||||
| 	} { | ||||
| 		if redirect := rdGetter(req); redirect != "" { | ||||
| 			return redirect, nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return "/", nil | ||||
| } | ||||
| 
 | ||||
| func isForwardedRequest(req *http.Request) bool { | ||||
| 	return requestutil.IsProxied(req) && | ||||
| 		req.Host != requestutil.GetRequestHost(req) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) hasProxyPrefix(path string) bool { | ||||
| 	return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) | ||||
| } | ||||
| 
 | ||||
| // getRdQuerystringRedirect handles this GetAppRedirect strategy:
 | ||||
| // - `rd` querysting parameter
 | ||||
| func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { | ||||
| 	redirect := req.Form.Get("rd") | ||||
| 	if p.IsValidRedirect(redirect) { | ||||
| 		return redirect | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // getXAuthRequestRedirect handles this GetAppRedirect strategy:
 | ||||
| // - `X-Auth-Request-Redirect` Header
 | ||||
| func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { | ||||
| 	redirect := req.Header.Get("X-Auth-Request-Redirect") | ||||
| 	if p.IsValidRedirect(redirect) { | ||||
| 		return redirect | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // getXForwardedHeadersRedirect handles these GetAppRedirect strategies:
 | ||||
| // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 | ||||
| // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 | ||||
| func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { | ||||
| 	if !isForwardedRequest(req) { | ||||
| 		return "" | ||||
| 	} | ||||
| 
 | ||||
| 	uri := requestutil.GetRequestURI(req) | ||||
| 	if p.hasProxyPrefix(uri) { | ||||
| 		uri = "/" | ||||
| 	} | ||||
| 
 | ||||
| 	redirect := fmt.Sprintf( | ||||
| 		"%s://%s%s", | ||||
| 		requestutil.GetRequestProto(req), | ||||
| 		requestutil.GetRequestHost(req), | ||||
| 		uri, | ||||
| 	) | ||||
| 
 | ||||
| 	if p.IsValidRedirect(redirect) { | ||||
| 		return redirect | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // getURIRedirect handles these GetAppRedirect strategies:
 | ||||
| // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 | ||||
| // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 | ||||
| // - `/`
 | ||||
| func (p *OAuthProxy) getURIRedirect(req *http.Request) string { | ||||
| 	redirect := requestutil.GetRequestURI(req) | ||||
| 	if !p.IsValidRedirect(redirect) { | ||||
| 		redirect = req.URL.RequestURI() | ||||
| 	} | ||||
| 
 | ||||
| 	if p.hasProxyPrefix(redirect) { | ||||
| 		return "/" | ||||
| 	} | ||||
| 	return redirect | ||||
| } | ||||
| 
 | ||||
| // splitHostPort separates host and port. If the port is not valid, it returns
 | ||||
| // the entire input as host, and it doesn't check the validity of the host.
 | ||||
| // Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
 | ||||
| // *** taken from net/url, modified validOptionalPort() to accept ":*"
 | ||||
| func splitHostPort(hostport string) (host, port string) { | ||||
| 	host = hostport | ||||
| 
 | ||||
| 	colon := strings.LastIndexByte(host, ':') | ||||
| 	if colon != -1 && validOptionalPort(host[colon:]) { | ||||
| 		host, port = host[:colon], host[colon+1:] | ||||
| 	} | ||||
| 
 | ||||
| 	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { | ||||
| 		host = host[1 : len(host)-1] | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // validOptionalPort reports whether port is either an empty string
 | ||||
| // or matches /^:\d*$/
 | ||||
| // *** taken from net/url, modified to accept ":*"
 | ||||
| func validOptionalPort(port string) bool { | ||||
| 	if port == "" || port == ":*" { | ||||
| 		return true | ||||
| 	} | ||||
| 	if port[0] != ':' { | ||||
| 		return false | ||||
| 	} | ||||
| 	for _, b := range port[1:] { | ||||
| 		if b < '0' || b > '9' { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // IsValidRedirect checks whether the redirect URL is whitelisted
 | ||||
| func (p *OAuthProxy) IsValidRedirect(redirect string) bool { | ||||
| 	switch { | ||||
| 	case redirect == "": | ||||
| 		// The user didn't specify a redirect, should fallback to `/`
 | ||||
| 		return false | ||||
| 	case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): | ||||
| 		return true | ||||
| 	case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): | ||||
| 		redirectURL, err := url.Parse(redirect) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) | ||||
| 			return false | ||||
| 		} | ||||
| 		redirectHostname := redirectURL.Hostname() | ||||
| 
 | ||||
| 		for _, domain := range p.whitelistDomains { | ||||
| 			domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) | ||||
| 			if domainHostname == "" { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { | ||||
| 				// the domain names match, now validate the ports
 | ||||
| 				// if the whitelisted domain's port is '*', allow all ports
 | ||||
| 				// if the whitelisted domain contains a specific port, only allow that port
 | ||||
| 				// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
 | ||||
| 				redirectPort := redirectURL.Port() | ||||
| 				if (domainPort == "*") || | ||||
| 					(domainPort == redirectPort) || | ||||
| 					(domainPort == "" && redirectPort == "") { | ||||
| 					return true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) | ||||
| 		return false | ||||
| 	default: | ||||
| 		logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // IsAllowedRequest is used to check if auth should be skipped for this request
 | ||||
| func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { | ||||
| 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" | ||||
| 	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req) | ||||
| } | ||||
| 
 | ||||
| // IsAllowedRoute is used to check if the request method & path is allowed without auth
 | ||||
| func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { | ||||
| 	for _, route := range p.allowedRoutes { | ||||
| 		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 | ||||
| var noCacheHeaders = map[string]string{ | ||||
| 	"Expires":         time.Unix(0, 0).Format(time.RFC1123), | ||||
| 	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0", | ||||
| 	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
 | ||||
| } | ||||
| 
 | ||||
| // prepareNoCache prepares headers for preventing browser caching.
 | ||||
| func prepareNoCache(w http.ResponseWriter) { | ||||
| 	// Set NoCache headers
 | ||||
| 	for k, v := range noCacheHeaders { | ||||
| 		w.Header().Set(k, v) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // IsTrustedIP is used to check if a request comes from a trusted client IP address.
 | ||||
| func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool { | ||||
| 	if p.trustedIPs == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) | ||||
| 		// Possibly spoofed X-Real-IP header
 | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if remoteAddr == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return p.trustedIPs.Has(remoteAddr) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { | ||||
| 		prepareNoCache(rw) | ||||
| 	} | ||||
| 
 | ||||
| 	switch path := req.URL.Path; { | ||||
| 	case path == p.RobotsPath: | ||||
| 		p.RobotsTxt(rw) | ||||
| 	case p.IsAllowedRequest(req): | ||||
| 		p.SkipAuthProxy(rw, req) | ||||
| 	case path == p.SignInPath: | ||||
| 		p.SignIn(rw, req) | ||||
| 	case path == p.SignOutPath: | ||||
| 		p.SignOut(rw, req) | ||||
| 	case path == p.OAuthStartPath: | ||||
| 		p.OAuthStart(rw, req) | ||||
| 	case path == p.OAuthCallbackPath: | ||||
| 		p.OAuthCallback(rw, req) | ||||
| 	case path == p.AuthOnlyPath: | ||||
| 		p.AuthOnly(rw, req) | ||||
| 	case path == p.UserInfoPath: | ||||
| 		p.UserInfo(rw, req) | ||||
| 	default: | ||||
| 		p.Proxy(rw, req) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SignIn serves a page prompting users to sign in
 | ||||
| func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||
| 	redirect, err := p.GetAppRedirect(req) | ||||
| 	redirect, err := p.getAppRedirect(req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | ||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||
|  | @ -893,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { | |||
| 
 | ||||
| // SignOut sends a response to clear the authentication cookie
 | ||||
| func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { | ||||
| 	redirect, err := p.GetAppRedirect(req) | ||||
| 	redirect, err := p.getAppRedirect(req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | ||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||
|  | @ -918,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | |||
| 		return | ||||
| 	} | ||||
| 	p.SetCSRFCookie(rw, req, nonce) | ||||
| 	redirect, err := p.GetAppRedirect(req) | ||||
| 	redirect, err := p.getAppRedirect(req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error obtaining redirect: %v", err) | ||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req)) | ||||
| 	redirectURI := p.getOAuthRedirectURI(req) | ||||
| 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) | ||||
| } | ||||
| 
 | ||||
|  | @ -947,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code")) | ||||
| 	session, err := p.redeemCode(req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) | ||||
| 		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") | ||||
|  | @ -1006,6 +805,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	code := req.Form.Get("code") | ||||
| 	if code == "" { | ||||
| 		return nil, providers.ErrMissingCode | ||||
| 	} | ||||
| 
 | ||||
| 	redirectURI := p.getOAuthRedirectURI(req) | ||||
| 	s, err := p.provider.Redeem(req.Context(), redirectURI, code) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return s, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { | ||||
| 	var err error | ||||
| 	if s.Email == "" { | ||||
| 		s.Email, err = p.provider.GetEmailAddress(ctx, s) | ||||
| 		if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return p.provider.EnrichSession(ctx, s) | ||||
| } | ||||
| 
 | ||||
| // AuthOnly checks whether the user is currently logged in (both authentication
 | ||||
| // and optional authorization).
 | ||||
| func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { | ||||
|  | @ -1023,7 +848,7 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { | |||
| 	} | ||||
| 
 | ||||
| 	// we are authenticated
 | ||||
| 	p.addHeadersForProxying(rw, req, session) | ||||
| 	p.addHeadersForProxying(rw, session) | ||||
| 	p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		rw.WriteHeader(http.StatusAccepted) | ||||
| 	})).ServeHTTP(rw, req) | ||||
|  | @ -1041,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | |||
| 	switch err { | ||||
| 	case nil: | ||||
| 		// we are authenticated
 | ||||
| 		p.addHeadersForProxying(rw, req, session) | ||||
| 		p.addHeadersForProxying(rw, session) | ||||
| 		p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) | ||||
| 	case ErrNeedsLogin: | ||||
| 		// we need to send the user to a login screen
 | ||||
| 		if isAjax(req) { | ||||
| 			// no point redirecting an AJAX request
 | ||||
| 			p.ErrorJSON(rw, http.StatusUnauthorized) | ||||
| 			p.errorJSON(rw, http.StatusUnauthorized) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
|  | @ -1066,7 +891,184 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | |||
| 		p.ErrorPage(rw, http.StatusInternalServerError, | ||||
| 			"Internal Error", "Internal Error") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 | ||||
| var noCacheHeaders = map[string]string{ | ||||
| 	"Expires":         time.Unix(0, 0).Format(time.RFC1123), | ||||
| 	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0", | ||||
| 	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
 | ||||
| } | ||||
| 
 | ||||
| // prepareNoCache prepares headers for preventing browser caching.
 | ||||
| func prepareNoCache(w http.ResponseWriter) { | ||||
| 	// Set NoCache headers
 | ||||
| 	for k, v := range noCacheHeaders { | ||||
| 		w.Header().Set(k, v) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
 | ||||
| // redirect clients to once authenticated.
 | ||||
| // This is usually the OAuthProxy callback URL.
 | ||||
| func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string { | ||||
| 	// if `p.redirectURL` already has a host, return it
 | ||||
| 	if p.redirectURL.Host != "" { | ||||
| 		return p.redirectURL.String() | ||||
| 	} | ||||
| 
 | ||||
| 	// Otherwise figure out the scheme + host from the request
 | ||||
| 	rd := *p.redirectURL | ||||
| 	rd.Host = requestutil.GetRequestHost(req) | ||||
| 	rd.Scheme = requestutil.GetRequestProto(req) | ||||
| 
 | ||||
| 	// If CookieSecure is true, return `https` no matter what
 | ||||
| 	// Not all reverse proxies set X-Forwarded-Proto
 | ||||
| 	if p.CookieSecure { | ||||
| 		rd.Scheme = schemeHTTPS | ||||
| 	} | ||||
| 	return rd.String() | ||||
| } | ||||
| 
 | ||||
| // getAppRedirect determines the full URL or URI path to redirect clients to
 | ||||
| // once authenticated with the OAuthProxy
 | ||||
| // Strategy priority (first legal result is used):
 | ||||
| // - `rd` querysting parameter
 | ||||
| // - `X-Auth-Request-Redirect` header
 | ||||
| // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 | ||||
| // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 | ||||
| // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 | ||||
| // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 | ||||
| // - `/`
 | ||||
| func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) { | ||||
| 	err := req.ParseForm() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	// These redirect getter functions are strategies ordered by priority
 | ||||
| 	// for figuring out the redirect URL.
 | ||||
| 	type redirectGetter func(req *http.Request) string | ||||
| 	for _, rdGetter := range []redirectGetter{ | ||||
| 		p.getRdQuerystringRedirect, | ||||
| 		p.getXAuthRequestRedirect, | ||||
| 		p.getXForwardedHeadersRedirect, | ||||
| 		p.getURIRedirect, | ||||
| 	} { | ||||
| 		if redirect := rdGetter(req); redirect != "" { | ||||
| 			return redirect, nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return "/", nil | ||||
| } | ||||
| 
 | ||||
| func isForwardedRequest(req *http.Request) bool { | ||||
| 	return requestutil.IsProxied(req) && | ||||
| 		req.Host != requestutil.GetRequestHost(req) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) hasProxyPrefix(path string) bool { | ||||
| 	return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) | ||||
| } | ||||
| 
 | ||||
| // getRdQuerystringRedirect handles this getAppRedirect strategy:
 | ||||
| // - `rd` querysting parameter
 | ||||
| func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { | ||||
| 	redirect := req.Form.Get("rd") | ||||
| 	if p.IsValidRedirect(redirect) { | ||||
| 		return redirect | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // getXAuthRequestRedirect handles this getAppRedirect strategy:
 | ||||
| // - `X-Auth-Request-Redirect` Header
 | ||||
| func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { | ||||
| 	redirect := req.Header.Get("X-Auth-Request-Redirect") | ||||
| 	if p.IsValidRedirect(redirect) { | ||||
| 		return redirect | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // getXForwardedHeadersRedirect handles these getAppRedirect strategies:
 | ||||
| // - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
 | ||||
| // - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
 | ||||
| func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { | ||||
| 	if !isForwardedRequest(req) { | ||||
| 		return "" | ||||
| 	} | ||||
| 
 | ||||
| 	uri := requestutil.GetRequestURI(req) | ||||
| 	if p.hasProxyPrefix(uri) { | ||||
| 		uri = "/" | ||||
| 	} | ||||
| 
 | ||||
| 	redirect := fmt.Sprintf( | ||||
| 		"%s://%s%s", | ||||
| 		requestutil.GetRequestProto(req), | ||||
| 		requestutil.GetRequestHost(req), | ||||
| 		uri, | ||||
| 	) | ||||
| 
 | ||||
| 	if p.IsValidRedirect(redirect) { | ||||
| 		return redirect | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // getURIRedirect handles these getAppRedirect strategies:
 | ||||
| // - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
 | ||||
| // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
 | ||||
| // - `/`
 | ||||
| func (p *OAuthProxy) getURIRedirect(req *http.Request) string { | ||||
| 	redirect := requestutil.GetRequestURI(req) | ||||
| 	if !p.IsValidRedirect(redirect) { | ||||
| 		redirect = req.URL.RequestURI() | ||||
| 	} | ||||
| 
 | ||||
| 	if p.hasProxyPrefix(redirect) { | ||||
| 		return "/" | ||||
| 	} | ||||
| 	return redirect | ||||
| } | ||||
| 
 | ||||
| // splitHostPort separates host and port. If the port is not valid, it returns
 | ||||
| // the entire input as host, and it doesn't check the validity of the host.
 | ||||
| // Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
 | ||||
| // *** taken from net/url, modified validOptionalPort() to accept ":*"
 | ||||
| func splitHostPort(hostport string) (host, port string) { | ||||
| 	host = hostport | ||||
| 
 | ||||
| 	colon := strings.LastIndexByte(host, ':') | ||||
| 	if colon != -1 && validOptionalPort(host[colon:]) { | ||||
| 		host, port = host[:colon], host[colon+1:] | ||||
| 	} | ||||
| 
 | ||||
| 	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { | ||||
| 		host = host[1 : len(host)-1] | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // validOptionalPort reports whether port is either an empty string
 | ||||
| // or matches /^:\d*$/
 | ||||
| // *** taken from net/url, modified to accept ":*"
 | ||||
| func validOptionalPort(port string) bool { | ||||
| 	if port == "" || port == ":*" { | ||||
| 		return true | ||||
| 	} | ||||
| 	if port[0] != ':' { | ||||
| 		return false | ||||
| 	} | ||||
| 	for _, b := range port[1:] { | ||||
| 		if b < '0' || b > '9' { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
 | ||||
|  | @ -1153,7 +1155,7 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} { | |||
| } | ||||
| 
 | ||||
| // addHeadersForProxying adds the appropriate headers the request / response for proxying
 | ||||
| func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) { | ||||
| func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) { | ||||
| 	if session.Email == "" { | ||||
| 		rw.Header().Set("GAP-Auth", session.User) | ||||
| 	} else { | ||||
|  | @ -1181,8 +1183,8 @@ func isAjax(req *http.Request) bool { | |||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // ErrorJSON returns the error code with an application/json mime type
 | ||||
| func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | ||||
| // errorJSON returns the error code with an application/json mime type
 | ||||
| func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) { | ||||
| 	rw.Header().Set("Content-Type", applicationJSON) | ||||
| 	rw.WriteHeader(code) | ||||
| } | ||||
|  |  | |||
|  | @ -415,8 +415,9 @@ func Test_redeemCode(t *testing.T) { | |||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = proxy.redeemCode(context.Background(), "www.example.com", "") | ||||
| 	assert.Error(t, err) | ||||
| 	req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 	_, err = proxy.redeemCode(req) | ||||
| 	assert.Equal(t, providers.ErrMissingCode, err) | ||||
| } | ||||
| 
 | ||||
| func Test_enrichSession(t *testing.T) { | ||||
|  | @ -1749,7 +1750,7 @@ func TestRequestSignature(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGetRedirect(t *testing.T) { | ||||
| func Test_getAppRedirect(t *testing.T) { | ||||
| 	opts := baseTestOptions() | ||||
| 	opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") | ||||
| 	err := validation.Validate(opts) | ||||
|  | @ -1900,7 +1901,7 @@ func TestGetRedirect(t *testing.T) { | |||
| 			req = middleware.AddRequestScope(req, &middleware.RequestScope{ | ||||
| 				ReverseProxy: tt.reverseProxy, | ||||
| 			}) | ||||
| 			redirect, err := proxy.GetAppRedirect(req) | ||||
| 			redirect, err := proxy.getAppRedirect(req) | ||||
| 
 | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, tt.expectedRedirect, redirect) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue