Self code review changes
This commit is contained in:
		
							parent
							
								
									8ec025f536
								
							
						
					
					
						commit
						37c415b889
					
				
							
								
								
									
										4
									
								
								main.go
								
								
								
								
							
							
						
						
									
										4
									
								
								main.go
								
								
								
								
							|  | @ -20,7 +20,7 @@ func main() { | ||||||
| 	flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError) | 	flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError) | ||||||
| 
 | 
 | ||||||
| 	emailDomains := StringArray{} | 	emailDomains := StringArray{} | ||||||
| 	whitelistandardomains := StringArray{} | 	whitelistDomains := StringArray{} | ||||||
| 	upstreams := StringArray{} | 	upstreams := StringArray{} | ||||||
| 	skipAuthRegex := StringArray{} | 	skipAuthRegex := StringArray{} | ||||||
| 	googleGroups := StringArray{} | 	googleGroups := StringArray{} | ||||||
|  | @ -49,7 +49,7 @@ func main() { | ||||||
| 	flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") | 	flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") | ||||||
| 
 | 
 | ||||||
| 	flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") | 	flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") | ||||||
| 	flagSet.Var(&whitelistandardomains, "whitelist-domain", "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)") | 	flagSet.Var(&whitelistDomains, "whitelist-domain", "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)") | ||||||
| 	flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") | 	flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") | ||||||
| 	flagSet.String("github-org", "", "restrict logins to members of this organisation") | 	flagSet.String("github-org", "", "restrict logins to members of this organisation") | ||||||
| 	flagSet.String("github-team", "", "restrict logins to members of this team") | 	flagSet.String("github-team", "", "restrict logins to members of this team") | ||||||
|  |  | ||||||
|  | @ -604,10 +604,10 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st | ||||||
| 	} | 	} | ||||||
| 	// check auth
 | 	// check auth
 | ||||||
| 	if p.HtpasswdFile.Validate(user, passwd) { | 	if p.HtpasswdFile.Validate(user, passwd) { | ||||||
| 		logger.PrintAuthf(user, req, logger.AuthSuccess, "Successful authentication via HtpasswdFile") | 		logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via HtpasswdFile") | ||||||
| 		return user, true | 		return user, true | ||||||
| 	} | 	} | ||||||
| 	logger.PrintAuthf(user, req, logger.AuthFailure, "Failed authentication via HtpasswdFile; unauthorized") | 	logger.PrintAuthf(user, req, logger.AuthFailure, "Invalid authentication via HtpasswdFile") | ||||||
| 	return "", false | 	return "", false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -755,27 +755,27 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	// finish the oauth cycle
 | 	// finish the oauth cycle
 | ||||||
| 	err := req.ParseForm() | 	err := req.ParseForm() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error while parsing OAuth callback: %s" + err.Error()) | 		logger.Printf("Error while parsing OAuth2 callback: %s" + err.Error()) | ||||||
| 		p.ErrorPage(rw, 500, "Internal Error", err.Error()) | 		p.ErrorPage(rw, 500, "Internal Error", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	errorString := req.Form.Get("error") | 	errorString := req.Form.Get("error") | ||||||
| 	if errorString != "" { | 	if errorString != "" { | ||||||
| 		logger.Printf("Error while parsing OAuth callback: %s ", errorString) | 		logger.Printf("Error while parsing OAuth2 callback: %s ", errorString) | ||||||
| 		p.ErrorPage(rw, 403, "Permission Denied", errorString) | 		p.ErrorPage(rw, 403, "Permission Denied", errorString) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	session, err := p.redeemCode(req.Host, req.Form.Get("code")) | 	session, err := p.redeemCode(req.Host, req.Form.Get("code")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error while parsing OAuth callback: %s ", errorString) | 		logger.Printf("Error redeeming code during OAuth2 callback: %s ", errorString) | ||||||
| 		p.ErrorPage(rw, 500, "Internal Error", "Internal Error") | 		p.ErrorPage(rw, 500, "Internal Error", "Internal Error") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	s := strings.SplitN(req.Form.Get("state"), ":", 2) | 	s := strings.SplitN(req.Form.Get("state"), ":", 2) | ||||||
| 	if len(s) != 2 { | 	if len(s) != 2 { | ||||||
| 		logger.Printf("Error while parsing OAuth state; invalid length") | 		logger.Printf("Error while parsing OAuth2 state; invalid length") | ||||||
| 		p.ErrorPage(rw, 500, "Internal Error", "Invalid State") | 		p.ErrorPage(rw, 500, "Internal Error", "Invalid State") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | @ -783,13 +783,13 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	redirect := s[1] | 	redirect := s[1] | ||||||
| 	c, err := req.Cookie(p.CSRFCookieName) | 	c, err := req.Cookie(p.CSRFCookieName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Failed authentication via oauth2; unable too obtain CSRF cookie") | 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2; unable too obtain CSRF cookie") | ||||||
| 		p.ErrorPage(rw, 403, "Permission Denied", err.Error()) | 		p.ErrorPage(rw, 403, "Permission Denied", err.Error()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	p.ClearCSRFCookie(rw, req) | 	p.ClearCSRFCookie(rw, req) | ||||||
| 	if c.Value != nonce { | 	if c.Value != nonce { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Failed authentication via oauth2; csrf token mismatch, potential attack") | 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2; csrf token mismatch, potential attack") | ||||||
| 		p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") | 		p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | @ -800,7 +800,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 
 | 
 | ||||||
| 	// set cookie, or deny
 | 	// set cookie, or deny
 | ||||||
| 	if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { | 	if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Successful authentication via oauth2; %s", session) | 		logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2; %s", session) | ||||||
| 		err := p.SaveSession(rw, req, session) | 		err := p.SaveSession(rw, req, session) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("%s %s", remoteAddr, err) | 			logger.Printf("%s %s", remoteAddr, err) | ||||||
|  | @ -809,7 +809,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		} | 		} | ||||||
| 		http.Redirect(rw, req, redirect, 302) | 		http.Redirect(rw, req, redirect, 302) | ||||||
| 	} else { | 	} else { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Failed authentication via oauth2; unauthorized") | 		logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Invalid authentication via OAuth2; unauthorized") | ||||||
| 		p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") | 		p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | @ -885,7 +885,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if session != nil && session.Email != "" && !p.Validator(session.Email) { | 	if session != nil && session.Email != "" && !p.Validator(session.Email) { | ||||||
| 		logger.Printf(session.Email, req, logger.AuthFailure, "Failed authentication via session; removing session %s", session) | 		logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session; removing session %s", session) | ||||||
| 		session = nil | 		session = nil | ||||||
| 		saveSession = false | 		saveSession = false | ||||||
| 		clearSession = true | 		clearSession = true | ||||||
|  | @ -979,10 +979,10 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, | ||||||
| 		return nil, fmt.Errorf("invalid format %s", b) | 		return nil, fmt.Errorf("invalid format %s", b) | ||||||
| 	} | 	} | ||||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | ||||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Successful authentication via basic auth") | 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||||
| 		return &providers.SessionState{User: pair[0]}, nil | 		return &providers.SessionState{User: pair[0]}, nil | ||||||
| 	} | 	} | ||||||
| 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Failed authentication via basic auth; not in Htpasswd file") | 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth; not in Htpasswd File") | ||||||
| 	return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) | 	return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue