Merge pull request #797 from grnhse/refactor-provider-authz
Centralize Provider authorization interface method
This commit is contained in:
		
						commit
						c377466411
					
				|  | @ -6,6 +6,11 @@ | ||||||
| 
 | 
 | ||||||
| - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. | - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. | ||||||
| - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. | - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. | ||||||
|  | - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this | ||||||
|  |   - Either `--google-group` or the new `--allowed-group` will work for Google now (`--google-group` will be used if both are set) | ||||||
|  |   - Group membership lists will be passed to the backend with the `X-Forwarded-Groups` header | ||||||
|  |   - If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately. | ||||||
|  |       - Previously, group membership was only checked on session creation and refresh. | ||||||
| - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex` | - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex` | ||||||
|   - We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version. |   - We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version. | ||||||
|   - If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly) |   - If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly) | ||||||
|  | @ -18,6 +23,9 @@ | ||||||
| ## Breaking Changes | ## Breaking Changes | ||||||
| 
 | 
 | ||||||
| - [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". | - [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". | ||||||
|  | - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow | ||||||
|  |   - If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately. | ||||||
|  |     - Previously, group membership was only checked on session creation and refresh. | ||||||
| - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass | - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass | ||||||
| - [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation. | - [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation. | ||||||
|   - You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy |   - You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy | ||||||
|  | @ -40,6 +48,7 @@ | ||||||
| - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves) | - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves) | ||||||
| - [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves) | - [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves) | ||||||
| - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed) | - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed) | ||||||
|  | - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Create universal Authorization behavior across providers (@NickMeves) | ||||||
| - [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed) | - [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed) | ||||||
| - [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock) | - [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock) | ||||||
| - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) | - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) | ||||||
|  |  | ||||||
|  | @ -42,6 +42,9 @@ var ( | ||||||
| 	// ErrNeedsLogin means the user should be redirected to the login page
 | 	// ErrNeedsLogin means the user should be redirected to the login page
 | ||||||
| 	ErrNeedsLogin = errors.New("redirect to login page") | 	ErrNeedsLogin = errors.New("redirect to login page") | ||||||
| 
 | 
 | ||||||
|  | 	// ErrAccessDenied means the user should receive a 401 Unauthorized response
 | ||||||
|  | 	ErrAccessDenied = errors.New("access denied") | ||||||
|  | 
 | ||||||
| 	// Used to check final redirects are not susceptible to open redirects.
 | 	// Used to check final redirects are not susceptible to open redirects.
 | ||||||
| 	// Matches //, /\ and both of these with whitespace in between (eg / / or / \).
 | 	// Matches //, /\ and both of these with whitespace in between (eg / / or / \).
 | ||||||
| 	invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`) | 	invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`) | ||||||
|  | @ -105,7 +108,6 @@ type OAuthProxy struct { | ||||||
| 	trustedIPs              *ip.NetSet | 	trustedIPs              *ip.NetSet | ||||||
| 	Banner                  string | 	Banner                  string | ||||||
| 	Footer                  string | 	Footer                  string | ||||||
| 	AllowedGroups           []string |  | ||||||
| 
 | 
 | ||||||
| 	sessionChain alice.Chain | 	sessionChain alice.Chain | ||||||
| 	headersChain alice.Chain | 	headersChain alice.Chain | ||||||
|  | @ -219,7 +221,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		Banner:                  opts.Banner, | 		Banner:                  opts.Banner, | ||||||
| 		Footer:                  opts.Footer, | 		Footer:                  opts.Footer, | ||||||
| 		SignInMessage:           buildSignInMessage(opts), | 		SignInMessage:           buildSignInMessage(opts), | ||||||
| 		AllowedGroups:           opts.AllowedGroups, |  | ||||||
| 
 | 
 | ||||||
| 		basicAuthValidator:  basicAuthValidator, | 		basicAuthValidator:  basicAuthValidator, | ||||||
| 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | ||||||
|  | @ -396,7 +397,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { | ||||||
| 
 | 
 | ||||||
| func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { | func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, errors.New("missing code") | 		return nil, providers.ErrMissingCode | ||||||
| 	} | 	} | ||||||
| 	redirectURI := p.GetRedirectURI(host) | 	redirectURI := p.GetRedirectURI(host) | ||||||
| 	s, err := p.provider.Redeem(ctx, redirectURI, code) | 	s, err := p.provider.Redeem(ctx, redirectURI, code) | ||||||
|  | @ -909,11 +910,15 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// set cookie, or deny
 | 	// set cookie, or deny
 | ||||||
| 	if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { | 	authorized, err := p.provider.Authorize(req.Context(), session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Error with authorization: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if p.Validator(session.Email) && authorized { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) | 		logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) | ||||||
| 		err := p.SaveSession(rw, req, session) | 		err := p.SaveSession(rw, req, session) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("Error saving session state for %s: %v", remoteAddr, err) | 			logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) | ||||||
| 			p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | 			p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | @ -967,6 +972,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| 			p.SignInPage(rw, req, http.StatusForbidden) | 			p.SignInPage(rw, req, http.StatusForbidden) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 	case ErrAccessDenied: | ||||||
|  | 		p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized") | ||||||
|  | 
 | ||||||
| 	default: | 	default: | ||||||
| 		// unknown error
 | 		// unknown error
 | ||||||
| 		logger.Errorf("Unexpected internal error: %v", err) | 		logger.Errorf("Unexpected internal error: %v", err) | ||||||
|  | @ -977,7 +985,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
 | // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
 | ||||||
| // Returns nil, ErrNeedsLogin if user needs to login.
 | // Returns:
 | ||||||
|  | // - `nil, ErrNeedsLogin` if user needs to login.
 | ||||||
|  | // - `nil, ErrAccessDenied` if the authenticated user is not authorized
 | ||||||
| // Set-Cookie headers may be set on the response as a side-effect of calling this method.
 | // Set-Cookie headers may be set on the response as a side-effect of calling this method.
 | ||||||
| func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
| 	var session *sessionsapi.SessionState | 	var session *sessionsapi.SessionState | ||||||
|  | @ -991,17 +1001,20 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R | ||||||
| 		return nil, ErrNeedsLogin | 		return nil, ErrNeedsLogin | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email) | 	invalidEmail := session.Email != "" && !p.Validator(session.Email) | ||||||
| 	invalidGroups := session != nil && !p.validateGroups(session.Groups) | 	authorized, err := p.provider.Authorize(req.Context(), session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Error with authorization: %v", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	if invalidEmail || invalidGroups { | 	if invalidEmail || !authorized { | ||||||
| 		logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) | 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session: removing session %s", session) | ||||||
| 		// Invalid session, clear it
 | 		// Invalid session, clear it
 | ||||||
| 		err := p.ClearSessionCookie(rw, req) | 		err := p.ClearSessionCookie(rw, req) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("Error clearing session cookie: %v", err) | 			logger.Errorf("Error clearing session cookie: %v", err) | ||||||
| 		} | 		} | ||||||
| 		return nil, ErrNeedsLogin | 		return nil, ErrAccessDenied | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return session, nil | 	return session, nil | ||||||
|  | @ -1033,23 +1046,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | ||||||
| 	rw.Header().Set("Content-Type", applicationJSON) | 	rw.Header().Set("Content-Type", applicationJSON) | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) validateGroups(groups []string) bool { |  | ||||||
| 	if len(p.AllowedGroups) == 0 { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	allowedGroups := map[string]struct{}{} |  | ||||||
| 
 |  | ||||||
| 	for _, group := range p.AllowedGroups { |  | ||||||
| 		allowedGroups[group] = struct{}{} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, group := range groups { |  | ||||||
| 		if _, ok := allowedGroups[group]; ok { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   opts.providerValidateCookieResponse, | 		ValidToken:   opts.providerValidateCookieResponse, | ||||||
| 	} | 	} | ||||||
|  | 	pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups) | ||||||
| 
 | 
 | ||||||
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | ||||||
| 	// access_token validation.
 | 	// access_token validation.
 | ||||||
|  | @ -1284,6 +1286,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   true, | 		ValidToken:   true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -1376,6 +1379,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   true, | 		ValidToken:   true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -1455,6 +1459,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   true, | 		ValidToken:   true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | ||||||
| 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | ||||||
| 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | ||||||
| 
 | 
 | ||||||
|  | 	p.SetAllowedGroups(o.AllowedGroups) | ||||||
|  | 
 | ||||||
| 	provider := providers.New(o.ProviderType, p) | 	provider := providers.New(o.ProviderType, p) | ||||||
| 	if provider == nil { | 	if provider == nil { | ||||||
| 		msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) | 		msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) | ||||||
|  | @ -255,7 +257,13 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) | 				msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) | ||||||
| 			} else { | 			} else { | ||||||
| 				p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file) | 				groups := o.AllowedGroups | ||||||
|  | 				// Backwards compatibility with `--google-group` option
 | ||||||
|  | 				if len(o.GoogleGroups) > 0 { | ||||||
|  | 					groups = o.GoogleGroups | ||||||
|  | 					p.SetAllowedGroups(groups) | ||||||
|  | 				} | ||||||
|  | 				p.SetGroupRestriction(groups, o.GoogleAdminEmail, file) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	case *providers.BitbucketProvider: | 	case *providers.BitbucketProvider: | ||||||
|  |  | ||||||
|  | @ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		return nil, ErrMissingCode | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	params := url.Values{} | 	params := url.Values{} | ||||||
|  | @ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	created := time.Now() | ||||||
| 	expires := time.Unix(jsonResponse.ExpiresOn, 0) | 	expires := time.Unix(jsonResponse.ExpiresOn, 0) | ||||||
| 	s = &sessions.SessionState{ | 
 | ||||||
|  | 	return &sessions.SessionState{ | ||||||
| 		AccessToken:  jsonResponse.AccessToken, | 		AccessToken:  jsonResponse.AccessToken, | ||||||
| 		IDToken:      jsonResponse.IDToken, | 		IDToken:      jsonResponse.IDToken, | ||||||
| 		CreatedAt:    &created, | 		CreatedAt:    &created, | ||||||
| 		ExpiresOn:    &expires, | 		ExpiresOn:    &expires, | ||||||
| 		RefreshToken: jsonResponse.RefreshToken, | 		RefreshToken: jsonResponse.RefreshToken, | ||||||
| 	} | 	}, nil | ||||||
| 	return |  | ||||||
| 
 |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
|  |  | ||||||
|  | @ -25,10 +25,16 @@ import ( | ||||||
| // GoogleProvider represents an Google based Identity Provider
 | // GoogleProvider represents an Google based Identity Provider
 | ||||||
| type GoogleProvider struct { | type GoogleProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
|  | 
 | ||||||
| 	RedeemRefreshURL *url.URL | 	RedeemRefreshURL *url.URL | ||||||
| 	// GroupValidator is a function that determines if the passed email is in
 | 
 | ||||||
| 	// the configured Google group.
 | 	// groupValidator is a function that determines if the user in the passed
 | ||||||
| 	GroupValidator func(string) bool | 	// session is a member of any of the configured Google groups.
 | ||||||
|  | 	//
 | ||||||
|  | 	// This hits the Google API for each group, so it is called on Redeem &
 | ||||||
|  | 	// Refresh. `Authorize` uses the results of this saved in `session.Groups`
 | ||||||
|  | 	// Since it is called on every request.
 | ||||||
|  | 	groupValidator func(*sessions.SessionState) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| var _ Provider = (*GoogleProvider)(nil) | var _ Provider = (*GoogleProvider)(nil) | ||||||
|  | @ -84,9 +90,9 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { | ||||||
| 	}) | 	}) | ||||||
| 	return &GoogleProvider{ | 	return &GoogleProvider{ | ||||||
| 		ProviderData: p, | 		ProviderData: p, | ||||||
| 		// Set a default GroupValidator to just always return valid (true), it will
 | 		// Set a default groupValidator to just always return valid (true), it will
 | ||||||
| 		// be overwritten if we configured a Google group restriction.
 | 		// be overwritten if we configured a Google group restriction.
 | ||||||
| 		GroupValidator: func(email string) bool { | 		groupValidator: func(*sessions.SessionState) bool { | ||||||
| 			return true | 			return true | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  | @ -118,14 +124,13 @@ func claimsFromIDToken(idToken string) (*claims, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		return nil, ErrMissingCode | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	params := url.Values{} | 	params := url.Values{} | ||||||
|  | @ -155,12 +160,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | ||||||
| 
 | 
 | ||||||
| 	c, err := claimsFromIDToken(jsonResponse.IDToken) | 	c, err := claimsFromIDToken(jsonResponse.IDToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	created := time.Now() | ||||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | ||||||
| 	s = &sessions.SessionState{ | 
 | ||||||
|  | 	return &sessions.SessionState{ | ||||||
| 		AccessToken:  jsonResponse.AccessToken, | 		AccessToken:  jsonResponse.AccessToken, | ||||||
| 		IDToken:      jsonResponse.IDToken, | 		IDToken:      jsonResponse.IDToken, | ||||||
| 		CreatedAt:    &created, | 		CreatedAt:    &created, | ||||||
|  | @ -168,18 +174,40 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | ||||||
| 		RefreshToken: jsonResponse.RefreshToken, | 		RefreshToken: jsonResponse.RefreshToken, | ||||||
| 		Email:        c.Email, | 		Email:        c.Email, | ||||||
| 		User:         c.Subject, | 		User:         c.Subject, | ||||||
|  | 	}, nil | ||||||
| } | } | ||||||
| 	return | 
 | ||||||
|  | // EnrichSessionState checks the listed Google Groups configured and adds any
 | ||||||
|  | // that the user is a member of to session.Groups.
 | ||||||
|  | func (p *GoogleProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error { | ||||||
|  | 	// TODO (@NickMeves) - Move to pure EnrichSessionState logic and stop
 | ||||||
|  | 	// reusing legacy `groupValidator`.
 | ||||||
|  | 	//
 | ||||||
|  | 	// This is called here to get the validator to do the `session.Groups`
 | ||||||
|  | 	// populating logic.
 | ||||||
|  | 	p.groupValidator(s) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SetGroupRestriction configures the GoogleProvider to restrict access to the
 | // SetGroupRestriction configures the GoogleProvider to restrict access to the
 | ||||||
| // specified group(s). AdminEmail has to be an administrative email on the domain that is
 | // specified group(s). AdminEmail has to be an administrative email on the domain that is
 | ||||||
| // checked. CredentialsFile is the path to a json file containing a Google service
 | // checked. CredentialsFile is the path to a json file containing a Google service
 | ||||||
| // account credentials.
 | // account credentials.
 | ||||||
|  | //
 | ||||||
|  | // TODO (@NickMeves) - Unit Test this OR refactor away from groupValidator func
 | ||||||
| func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { | func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { | ||||||
| 	adminService := getAdminService(adminEmail, credentialsReader) | 	adminService := getAdminService(adminEmail, credentialsReader) | ||||||
| 	p.GroupValidator = func(email string) bool { | 	p.groupValidator = func(s *sessions.SessionState) bool { | ||||||
| 		return userInGroup(adminService, groups, email) | 		// Reset our saved Groups in case membership changed
 | ||||||
|  | 		// This is used by `Authorize` on every request
 | ||||||
|  | 		s.Groups = make([]string, 0, len(groups)) | ||||||
|  | 		for _, group := range groups { | ||||||
|  | 			if userInGroup(adminService, group, s.Email) { | ||||||
|  | 				s.Groups = append(s.Groups, group) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return len(s.Groups) > 0 | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -203,12 +231,14 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv | ||||||
| 	return adminService | 	return adminService | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func userInGroup(service *admin.Service, groups []string, email string) bool { | func userInGroup(service *admin.Service, group string, email string) bool { | ||||||
| 	for _, group := range groups { |  | ||||||
| 	// Use the HasMember API to checking for the user's presence in each group or nested subgroups
 | 	// Use the HasMember API to checking for the user's presence in each group or nested subgroups
 | ||||||
| 	req := service.Members.HasMember(group, email) | 	req := service.Members.HasMember(group, email) | ||||||
| 	r, err := req.Do() | 	r, err := req.Do() | ||||||
| 		if err != nil { | 	if err == nil { | ||||||
|  | 		return r.IsMember | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	gerr, ok := err.(*googleapi.Error) | 	gerr, ok := err.(*googleapi.Error) | ||||||
| 	switch { | 	switch { | ||||||
| 	case ok && gerr.Code == 404: | 	case ok && gerr.Code == 404: | ||||||
|  | @ -220,10 +250,9 @@ func userInGroup(service *admin.Service, groups []string, email string) bool { | ||||||
| 		// from the HasMember API. In that case, attempt to query the member object directly from the group.
 | 		// from the HasMember API. In that case, attempt to query the member object directly from the group.
 | ||||||
| 		req := service.Members.Get(group, email) | 		req := service.Members.Get(group, email) | ||||||
| 		r, err := req.Do() | 		r, err := req.Do() | ||||||
| 
 |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) | 			logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) | ||||||
| 					continue | 			return false | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// If the non-domain user is found within the group, still verify that they are "ACTIVE".
 | 		// If the non-domain user is found within the group, still verify that they are "ACTIVE".
 | ||||||
|  | @ -234,21 +263,9 @@ func userInGroup(service *admin.Service, groups []string, email string) bool { | ||||||
| 	default: | 	default: | ||||||
| 		logger.Errorf("error checking group membership: %v", err) | 		logger.Errorf("error checking group membership: %v", err) | ||||||
| 	} | 	} | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if r.IsMember { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateGroup validates that the provided email exists in the configured Google
 |  | ||||||
| // group(s).
 |  | ||||||
| func (p *GoogleProvider) ValidateGroup(email string) bool { |  | ||||||
| 	return p.GroupValidator(email) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | // RefreshToken to fetch a new ID token if required
 | ||||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | @ -261,8 +278,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// TODO (@NickMeves) - Align Group authorization needs with other providers'
 | ||||||
|  | 	// behavior in the `RefreshSession` case.
 | ||||||
|  | 	//
 | ||||||
| 	// re-check that the user is in the proper google group(s)
 | 	// re-check that the user is in the proper google group(s)
 | ||||||
| 	if !p.ValidateGroup(s.Email) { | 	if !p.groupValidator(s) { | ||||||
| 		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) | 		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -10,6 +10,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	admin "google.golang.org/api/admin/directory/v1" | 	admin "google.golang.org/api/admin/directory/v1" | ||||||
|  | @ -109,21 +110,50 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { | ||||||
| 	assert.Equal(t, "refresh12345", session.RefreshToken) | 	assert.Equal(t, "refresh12345", session.RefreshToken) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderValidateGroup(t *testing.T) { | func TestGoogleProviderGroupValidator(t *testing.T) { | ||||||
| 	p := newGoogleProvider() | 	const sessionEmail = "michael.bland@gsa.gov" | ||||||
| 	p.GroupValidator = func(email string) bool { |  | ||||||
| 		return email == "michael.bland@gsa.gov" |  | ||||||
| 	} |  | ||||||
| 	assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) |  | ||||||
| 	p.GroupValidator = func(email string) bool { |  | ||||||
| 		return email != "michael.bland@gsa.gov" |  | ||||||
| 	} |  | ||||||
| 	assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderWithoutValidateGroup(t *testing.T) { | 	testCases := map[string]struct { | ||||||
|  | 		session       *sessions.SessionState | ||||||
|  | 		validatorFunc func(*sessions.SessionState) bool | ||||||
|  | 		expectedAuthZ bool | ||||||
|  | 	}{ | ||||||
|  | 		"Email is authorized with groupValidator": { | ||||||
|  | 			session: &sessions.SessionState{ | ||||||
|  | 				Email: sessionEmail, | ||||||
|  | 			}, | ||||||
|  | 			validatorFunc: func(s *sessions.SessionState) bool { | ||||||
|  | 				return s.Email == sessionEmail | ||||||
|  | 			}, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 		"Email is denied with groupValidator": { | ||||||
|  | 			session: &sessions.SessionState{ | ||||||
|  | 				Email: sessionEmail, | ||||||
|  | 			}, | ||||||
|  | 			validatorFunc: func(s *sessions.SessionState) bool { | ||||||
|  | 				return s.Email != sessionEmail | ||||||
|  | 			}, | ||||||
|  | 			expectedAuthZ: false, | ||||||
|  | 		}, | ||||||
|  | 		"Default does no authorization checks": { | ||||||
|  | 			session: &sessions.SessionState{ | ||||||
|  | 				Email: sessionEmail, | ||||||
|  | 			}, | ||||||
|  | 			validatorFunc: nil, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for name, tc := range testCases { | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			g := NewWithT(t) | ||||||
| 			p := newGoogleProvider() | 			p := newGoogleProvider() | ||||||
| 	assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) | 			if tc.validatorFunc != nil { | ||||||
|  | 				p.groupValidator = tc.validatorFunc | ||||||
|  | 			} | ||||||
|  | 			g.Expect(p.groupValidator(tc.session)).To(Equal(tc.expectedAuthZ)) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //
 | //
 | ||||||
|  | @ -196,7 +226,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderUserInGroup(t *testing.T) { | func TestGoogleProvider_userInGroup(t *testing.T) { | ||||||
| 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" { | 		if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" { | ||||||
| 			fmt.Fprintln(w, `{"isMember": true}`) | 			fmt.Fprintln(w, `{"isMember": true}`) | ||||||
|  | @ -233,18 +263,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) { | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 
 | 
 | ||||||
| 	service, err := admin.NewService(ctx, option.WithHTTPClient(client)) | 	service, err := admin.NewService(ctx, option.WithHTTPClient(client)) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
| 	service.BasePath = ts.URL | 	service.BasePath = ts.URL | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 
 | 
 | ||||||
| 	result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com") | 	result := userInGroup(service, "group@example.com", "member-in-domain@example.com") | ||||||
| 	assert.True(t, result) | 	assert.True(t, result) | ||||||
| 
 | 
 | ||||||
| 	result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com") | 	result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com") | ||||||
| 	assert.True(t, result) | 	assert.True(t, result) | ||||||
| 
 | 
 | ||||||
| 	result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com") | 	result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com") | ||||||
| 	assert.False(t, result) | 	assert.False(t, result) | ||||||
| 
 | 
 | ||||||
| 	result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com") | 	result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com") | ||||||
| 	assert.False(t, result) | 	assert.False(t, result) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/rsa" | 	"crypto/rsa" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | @ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		return nil, ErrMissingCode | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	claims := &jwt.StandardClaims{ | 	claims := &jwt.StandardClaims{ | ||||||
|  | @ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | ||||||
| 	token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) | 	token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) | ||||||
| 	ss, err := token.SignedString(p.JWTKey) | 	ss, err := token.SignedString(p.JWTKey) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	params := url.Values{} | 	params := url.Values{} | ||||||
|  | @ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | ||||||
| 	// check nonce here
 | 	// check nonce here
 | ||||||
| 	err = checkNonce(jsonResponse.IDToken, p) | 	err = checkNonce(jsonResponse.IDToken, p) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Get the email address
 | 	// Get the email address
 | ||||||
| 	var email string | 	var email string | ||||||
| 	email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String()) | 	email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	created := time.Now() | ||||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | ||||||
| 
 | 
 | ||||||
| 	// Store the data that we found in the session state
 | 	// Store the data that we found in the session state
 | ||||||
| 	s = &sessions.SessionState{ | 	return &sessions.SessionState{ | ||||||
| 		AccessToken: jsonResponse.AccessToken, | 		AccessToken: jsonResponse.AccessToken, | ||||||
| 		IDToken:     jsonResponse.IDToken, | 		IDToken:     jsonResponse.IDToken, | ||||||
| 		CreatedAt:   &created, | 		CreatedAt:   &created, | ||||||
| 		ExpiresOn:   &expires, | 		ExpiresOn:   &expires, | ||||||
| 		Email:       email, | 		Email:       email, | ||||||
| 	} | 	}, nil | ||||||
| 	return |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetLoginURL overrides GetLoginURL to add login.gov parameters
 | // GetLoginURL overrides GetLoginURL to add login.gov parameters
 | ||||||
|  |  | ||||||
|  | @ -26,6 +26,10 @@ type ProviderData struct { | ||||||
| 	ClientSecretFile string | 	ClientSecretFile string | ||||||
| 	Scope            string | 	Scope            string | ||||||
| 	Prompt           string | 	Prompt           string | ||||||
|  | 
 | ||||||
|  | 	// Universal Group authorization data structure
 | ||||||
|  | 	// any provider can set to consume
 | ||||||
|  | 	AllowedGroups map[string]struct{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Data returns the ProviderData
 | // Data returns the ProviderData
 | ||||||
|  | @ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) { | ||||||
| 	return string(fileClientSecret), nil | 	return string(fileClientSecret), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SetAllowedGroups organizes a group list into the AllowedGroups map
 | ||||||
|  | // to be consumed by Authorize implementations
 | ||||||
|  | func (p *ProviderData) SetAllowedGroups(groups []string) { | ||||||
|  | 	p.AllowedGroups = make(map[string]struct{}, len(groups)) | ||||||
|  | 	for _, group := range groups { | ||||||
|  | 		p.AllowedGroups[group] = struct{}{} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type providerDefaults struct { | type providerDefaults struct { | ||||||
| 	name        string | 	name        string | ||||||
| 	loginURL    *url.URL | 	loginURL    *url.URL | ||||||
|  |  | ||||||
|  | @ -19,18 +19,21 @@ var ( | ||||||
| 	// implementation method that doesn't have sensible defaults
 | 	// implementation method that doesn't have sensible defaults
 | ||||||
| 	ErrNotImplemented = errors.New("not implemented") | 	ErrNotImplemented = errors.New("not implemented") | ||||||
| 
 | 
 | ||||||
|  | 	// ErrMissingCode is returned when a Redeem method is called with an empty
 | ||||||
|  | 	// code
 | ||||||
|  | 	ErrMissingCode = errors.New("missing code") | ||||||
|  | 
 | ||||||
| 	_ Provider = (*ProviderData)(nil) | 	_ Provider = (*ProviderData)(nil) | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Redeem provides a default implementation of the OAuth2 token redemption process
 | // Redeem provides a default implementation of the OAuth2 token redemption process
 | ||||||
| func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		return nil, ErrMissingCode | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	params := url.Values{} | 	params := url.Values{} | ||||||
|  | @ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 	} | 	} | ||||||
| 	err = result.UnmarshalInto(&jsonResponse) | 	err = result.UnmarshalInto(&jsonResponse) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		s = &sessions.SessionState{ | 		return &sessions.SessionState{ | ||||||
| 			AccessToken: jsonResponse.AccessToken, | 			AccessToken: jsonResponse.AccessToken, | ||||||
| 		} | 		}, nil | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var v url.Values | 	values, err := url.ParseQuery(string(result.Body())) | ||||||
| 	v, err = url.ParseQuery(string(result.Body())) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if a := v.Get("access_token"); a != "" { | 	if token := values.Get("access_token"); token != "" { | ||||||
| 		created := time.Now() | 		created := time.Now() | ||||||
| 		s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} | 		return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil | ||||||
| 	} else { |  | ||||||
| 		err = fmt.Errorf("no access token found %s", result.Body()) |  | ||||||
| 	} | 	} | ||||||
| 	return | 
 | ||||||
|  | 	return nil, fmt.Errorf("no access token found %s", result.Body()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetLoginURL with typical oauth parameters
 | // GetLoginURL with typical oauth parameters
 | ||||||
|  | @ -92,18 +92,28 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta | ||||||
| 	return "", ErrNotImplemented | 	return "", ErrNotImplemented | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateGroup validates that the provided email exists in the configured provider
 |  | ||||||
| // email group(s).
 |  | ||||||
| func (p *ProviderData) ValidateGroup(_ string) bool { |  | ||||||
| 	return true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | ||||||
| // such as User, Email, Groups with provider specific API calls.
 | // such as User, Email, Groups with provider specific API calls.
 | ||||||
| func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { | func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Authorize performs global authorization on an authenticated session.
 | ||||||
|  | // This is not used for fine-grained per route authorization rules.
 | ||||||
|  | func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 	if len(p.AllowedGroups) == 0 { | ||||||
|  | 		return true, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, group := range s.Groups { | ||||||
|  | 		if _, ok := p.AllowedGroups[group]; ok { | ||||||
|  | 			return true, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, nil) | 	return validateToken(ctx, p, s.AccessToken, nil) | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) { | ||||||
| 	s := &sessions.SessionState{} | 	s := &sessions.SessionState{} | ||||||
| 	assert.NoError(t, p.EnrichSessionState(context.Background(), s)) | 	assert.NoError(t, p.EnrichSessionState(context.Background(), s)) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestProviderDataAuthorize(t *testing.T) { | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		name          string | ||||||
|  | 		allowedGroups []string | ||||||
|  | 		groups        []string | ||||||
|  | 		expectedAuthZ bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:          "NoAllowedGroups", | ||||||
|  | 			allowedGroups: []string{}, | ||||||
|  | 			groups:        []string{}, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "NoAllowedGroupsUserHasGroups", | ||||||
|  | 			allowedGroups: []string{}, | ||||||
|  | 			groups:        []string{"foo", "bar"}, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "UserInAllowedGroup", | ||||||
|  | 			allowedGroups: []string{"foo"}, | ||||||
|  | 			groups:        []string{"foo", "bar"}, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "UserNotInAllowedGroup", | ||||||
|  | 			allowedGroups: []string{"bar"}, | ||||||
|  | 			groups:        []string{"baz", "foo"}, | ||||||
|  | 			expectedAuthZ: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range testCases { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			g := NewWithT(t) | ||||||
|  | 
 | ||||||
|  | 			session := &sessions.SessionState{ | ||||||
|  | 				Groups: tc.groups, | ||||||
|  | 			} | ||||||
|  | 			p := &ProviderData{} | ||||||
|  | 			p.SetAllowedGroups(tc.allowedGroups) | ||||||
|  | 
 | ||||||
|  | 			authorized, err := p.Authorize(context.Background(), session) | ||||||
|  | 			g.Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			g.Expect(authorized).To(Equal(tc.expectedAuthZ)) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -13,8 +13,8 @@ type Provider interface { | ||||||
| 	// DEPRECATED: Migrate to EnrichSessionState
 | 	// DEPRECATED: Migrate to EnrichSessionState
 | ||||||
| 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) | 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) | ||||||
| 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) | 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) | ||||||
| 	ValidateGroup(string) bool |  | ||||||
| 	EnrichSessionState(ctx context.Context, s *sessions.SessionState) error | 	EnrichSessionState(ctx context.Context, s *sessions.SessionState) error | ||||||
|  | 	Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||||
| 	ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool | 	ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool | ||||||
| 	GetLoginURL(redirectURI, finalRedirect string) string | 	GetLoginURL(redirectURI, finalRedirect string) string | ||||||
| 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue