Merge pull request #797 from grnhse/refactor-provider-authz
Centralize Provider authorization interface method
This commit is contained in:
		
						commit
						c377466411
					
				|  | @ -6,6 +6,11 @@ | |||
| 
 | ||||
| - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. | ||||
| - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. | ||||
| - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this | ||||
|   - Either `--google-group` or the new `--allowed-group` will work for Google now (`--google-group` will be used if both are set) | ||||
|   - Group membership lists will be passed to the backend with the `X-Forwarded-Groups` header | ||||
|   - If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately. | ||||
|       - Previously, group membership was only checked on session creation and refresh. | ||||
| - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) `--skip-auth-route` is (almost) backwards compatible with `--skip-auth-regex` | ||||
|   - We are marking `--skip-auth-regex` as DEPRECATED and will remove it in the next major version. | ||||
|   - If your regex contains an `=` and you want it for all methods, you will need to add a leading `=` (this is the area where `--skip-auth-regex` doesn't port perfectly) | ||||
|  | @ -18,6 +23,9 @@ | |||
| ## Breaking Changes | ||||
| 
 | ||||
| - [#911](https://github.com/oauth2-proxy/oauth2-rpoxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". | ||||
| - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow | ||||
|   - If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately. | ||||
|     - Previously, group membership was only checked on session creation and refresh. | ||||
| - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass | ||||
| - [#800](https://github.com/oauth2-proxy/oauth2-proxy/pull/800) Fix import path for v7. The import path has changed to support the go get installation. | ||||
|   - You can now `go get github.com/oauth2-proxy/oauth2-proxy/v7` to get the latest `v7` version of OAuth2 Proxy | ||||
|  | @ -40,6 +48,7 @@ | |||
| - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves) | ||||
| - [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves) | ||||
| - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) Integrate new header injectors into project (@JoelSpeed) | ||||
| - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Create universal Authorization behavior across providers (@NickMeves) | ||||
| - [#898](https://github.com/oauth2-proxy/oauth2-proxy/pull/898) Migrate documentation to Docusaurus (@JoelSpeed) | ||||
| - [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock) | ||||
| - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) | ||||
|  |  | |||
|  | @ -42,6 +42,9 @@ var ( | |||
| 	// ErrNeedsLogin means the user should be redirected to the login page
 | ||||
| 	ErrNeedsLogin = errors.New("redirect to login page") | ||||
| 
 | ||||
| 	// ErrAccessDenied means the user should receive a 401 Unauthorized response
 | ||||
| 	ErrAccessDenied = errors.New("access denied") | ||||
| 
 | ||||
| 	// Used to check final redirects are not susceptible to open redirects.
 | ||||
| 	// Matches //, /\ and both of these with whitespace in between (eg / / or / \).
 | ||||
| 	invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`) | ||||
|  | @ -105,7 +108,6 @@ type OAuthProxy struct { | |||
| 	trustedIPs              *ip.NetSet | ||||
| 	Banner                  string | ||||
| 	Footer                  string | ||||
| 	AllowedGroups           []string | ||||
| 
 | ||||
| 	sessionChain alice.Chain | ||||
| 	headersChain alice.Chain | ||||
|  | @ -219,7 +221,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		Banner:                  opts.Banner, | ||||
| 		Footer:                  opts.Footer, | ||||
| 		SignInMessage:           buildSignInMessage(opts), | ||||
| 		AllowedGroups:           opts.AllowedGroups, | ||||
| 
 | ||||
| 		basicAuthValidator:  basicAuthValidator, | ||||
| 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | ||||
|  | @ -396,7 +397,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { | |||
| 
 | ||||
| func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("missing code") | ||||
| 		return nil, providers.ErrMissingCode | ||||
| 	} | ||||
| 	redirectURI := p.GetRedirectURI(host) | ||||
| 	s, err := p.provider.Redeem(ctx, redirectURI, code) | ||||
|  | @ -909,11 +910,15 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | |||
| 	} | ||||
| 
 | ||||
| 	// set cookie, or deny
 | ||||
| 	if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { | ||||
| 	authorized, err := p.provider.Authorize(req.Context(), session) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error with authorization: %v", err) | ||||
| 	} | ||||
| 	if p.Validator(session.Email) && authorized { | ||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) | ||||
| 		err := p.SaveSession(rw, req, session) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Error saving session state for %s: %v", remoteAddr, err) | ||||
| 			logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) | ||||
| 			p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) | ||||
| 			return | ||||
| 		} | ||||
|  | @ -967,6 +972,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | |||
| 			p.SignInPage(rw, req, http.StatusForbidden) | ||||
| 		} | ||||
| 
 | ||||
| 	case ErrAccessDenied: | ||||
| 		p.ErrorPage(rw, http.StatusUnauthorized, "Permission Denied", "Unauthorized") | ||||
| 
 | ||||
| 	default: | ||||
| 		// unknown error
 | ||||
| 		logger.Errorf("Unexpected internal error: %v", err) | ||||
|  | @ -977,7 +985,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | |||
| } | ||||
| 
 | ||||
| // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
 | ||||
| // Returns nil, ErrNeedsLogin if user needs to login.
 | ||||
| // Returns:
 | ||||
| // - `nil, ErrNeedsLogin` if user needs to login.
 | ||||
| // - `nil, ErrAccessDenied` if the authenticated user is not authorized
 | ||||
| // Set-Cookie headers may be set on the response as a side-effect of calling this method.
 | ||||
| func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	var session *sessionsapi.SessionState | ||||
|  | @ -991,17 +1001,20 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R | |||
| 		return nil, ErrNeedsLogin | ||||
| 	} | ||||
| 
 | ||||
| 	invalidEmail := session != nil && session.Email != "" && !p.Validator(session.Email) | ||||
| 	invalidGroups := session != nil && !p.validateGroups(session.Groups) | ||||
| 	invalidEmail := session.Email != "" && !p.Validator(session.Email) | ||||
| 	authorized, err := p.provider.Authorize(req.Context(), session) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error with authorization: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if invalidEmail || invalidGroups { | ||||
| 		logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) | ||||
| 	if invalidEmail || !authorized { | ||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session: removing session %s", session) | ||||
| 		// Invalid session, clear it
 | ||||
| 		err := p.ClearSessionCookie(rw, req) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Error clearing session cookie: %v", err) | ||||
| 			logger.Errorf("Error clearing session cookie: %v", err) | ||||
| 		} | ||||
| 		return nil, ErrNeedsLogin | ||||
| 		return nil, ErrAccessDenied | ||||
| 	} | ||||
| 
 | ||||
| 	return session, nil | ||||
|  | @ -1033,23 +1046,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | |||
| 	rw.Header().Set("Content-Type", applicationJSON) | ||||
| 	rw.WriteHeader(code) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) validateGroups(groups []string) bool { | ||||
| 	if len(p.AllowedGroups) == 0 { | ||||
| 		return true | ||||
| 	} | ||||
| 
 | ||||
| 	allowedGroups := map[string]struct{}{} | ||||
| 
 | ||||
| 	for _, group := range p.AllowedGroups { | ||||
| 		allowedGroups[group] = struct{}{} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, group := range groups { | ||||
| 		if _, ok := allowedGroups[group]; ok { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
|  |  | |||
|  | @ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi | |||
| 		return nil, err | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   opts.providerValidateCookieResponse, | ||||
| 	} | ||||
| 	pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups) | ||||
| 
 | ||||
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | ||||
| 	// access_token validation.
 | ||||
|  | @ -1284,6 +1286,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | |||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   true, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -1376,6 +1379,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { | |||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   true, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -1455,6 +1459,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { | |||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   true, | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | |||
| 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | ||||
| 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | ||||
| 
 | ||||
| 	p.SetAllowedGroups(o.AllowedGroups) | ||||
| 
 | ||||
| 	provider := providers.New(o.ProviderType, p) | ||||
| 	if provider == nil { | ||||
| 		msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) | ||||
|  | @ -255,7 +257,13 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | |||
| 			if err != nil { | ||||
| 				msgs = append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) | ||||
| 			} else { | ||||
| 				p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file) | ||||
| 				groups := o.AllowedGroups | ||||
| 				// Backwards compatibility with `--google-group` option
 | ||||
| 				if len(o.GoogleGroups) > 0 { | ||||
| 					groups = o.GoogleGroups | ||||
| 					p.SetAllowedGroups(groups) | ||||
| 				} | ||||
| 				p.SetGroupRestriction(groups, o.GoogleAdminEmail, file) | ||||
| 			} | ||||
| 		} | ||||
| 	case *providers.BitbucketProvider: | ||||
|  |  | |||
|  | @ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) { | |||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
| 		return nil, ErrMissingCode | ||||
| 	} | ||||
| 	clientSecret, err := p.GetClientSecret() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	params := url.Values{} | ||||
|  | @ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 
 | ||||
| 	created := time.Now() | ||||
| 	expires := time.Unix(jsonResponse.ExpiresOn, 0) | ||||
| 	s = &sessions.SessionState{ | ||||
| 
 | ||||
| 	return &sessions.SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		IDToken:      jsonResponse.IDToken, | ||||
| 		CreatedAt:    &created, | ||||
| 		ExpiresOn:    &expires, | ||||
| 		RefreshToken: jsonResponse.RefreshToken, | ||||
| 	} | ||||
| 	return | ||||
| 
 | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
|  |  | |||
|  | @ -25,10 +25,16 @@ import ( | |||
| // GoogleProvider represents an Google based Identity Provider
 | ||||
| type GoogleProvider struct { | ||||
| 	*ProviderData | ||||
| 
 | ||||
| 	RedeemRefreshURL *url.URL | ||||
| 	// GroupValidator is a function that determines if the passed email is in
 | ||||
| 	// the configured Google group.
 | ||||
| 	GroupValidator func(string) bool | ||||
| 
 | ||||
| 	// groupValidator is a function that determines if the user in the passed
 | ||||
| 	// session is a member of any of the configured Google groups.
 | ||||
| 	//
 | ||||
| 	// This hits the Google API for each group, so it is called on Redeem &
 | ||||
| 	// Refresh. `Authorize` uses the results of this saved in `session.Groups`
 | ||||
| 	// Since it is called on every request.
 | ||||
| 	groupValidator func(*sessions.SessionState) bool | ||||
| } | ||||
| 
 | ||||
| var _ Provider = (*GoogleProvider)(nil) | ||||
|  | @ -84,9 +90,9 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { | |||
| 	}) | ||||
| 	return &GoogleProvider{ | ||||
| 		ProviderData: p, | ||||
| 		// Set a default GroupValidator to just always return valid (true), it will
 | ||||
| 		// Set a default groupValidator to just always return valid (true), it will
 | ||||
| 		// be overwritten if we configured a Google group restriction.
 | ||||
| 		GroupValidator: func(email string) bool { | ||||
| 		groupValidator: func(*sessions.SessionState) bool { | ||||
| 			return true | ||||
| 		}, | ||||
| 	} | ||||
|  | @ -118,14 +124,13 @@ func claimsFromIDToken(idToken string) (*claims, error) { | |||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
| 		return nil, ErrMissingCode | ||||
| 	} | ||||
| 	clientSecret, err := p.GetClientSecret() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	params := url.Values{} | ||||
|  | @ -155,12 +160,13 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | |||
| 
 | ||||
| 	c, err := claimsFromIDToken(jsonResponse.IDToken) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | ||||
| 	s = &sessions.SessionState{ | ||||
| 
 | ||||
| 	return &sessions.SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		IDToken:      jsonResponse.IDToken, | ||||
| 		CreatedAt:    &created, | ||||
|  | @ -168,18 +174,40 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | |||
| 		RefreshToken: jsonResponse.RefreshToken, | ||||
| 		Email:        c.Email, | ||||
| 		User:         c.Subject, | ||||
| 	}, nil | ||||
| } | ||||
| 	return | ||||
| 
 | ||||
| // EnrichSessionState checks the listed Google Groups configured and adds any
 | ||||
| // that the user is a member of to session.Groups.
 | ||||
| func (p *GoogleProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error { | ||||
| 	// TODO (@NickMeves) - Move to pure EnrichSessionState logic and stop
 | ||||
| 	// reusing legacy `groupValidator`.
 | ||||
| 	//
 | ||||
| 	// This is called here to get the validator to do the `session.Groups`
 | ||||
| 	// populating logic.
 | ||||
| 	p.groupValidator(s) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // SetGroupRestriction configures the GoogleProvider to restrict access to the
 | ||||
| // specified group(s). AdminEmail has to be an administrative email on the domain that is
 | ||||
| // checked. CredentialsFile is the path to a json file containing a Google service
 | ||||
| // account credentials.
 | ||||
| //
 | ||||
| // TODO (@NickMeves) - Unit Test this OR refactor away from groupValidator func
 | ||||
| func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { | ||||
| 	adminService := getAdminService(adminEmail, credentialsReader) | ||||
| 	p.GroupValidator = func(email string) bool { | ||||
| 		return userInGroup(adminService, groups, email) | ||||
| 	p.groupValidator = func(s *sessions.SessionState) bool { | ||||
| 		// Reset our saved Groups in case membership changed
 | ||||
| 		// This is used by `Authorize` on every request
 | ||||
| 		s.Groups = make([]string, 0, len(groups)) | ||||
| 		for _, group := range groups { | ||||
| 			if userInGroup(adminService, group, s.Email) { | ||||
| 				s.Groups = append(s.Groups, group) | ||||
| 			} | ||||
| 		} | ||||
| 		return len(s.Groups) > 0 | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -203,12 +231,14 @@ func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Serv | |||
| 	return adminService | ||||
| } | ||||
| 
 | ||||
| func userInGroup(service *admin.Service, groups []string, email string) bool { | ||||
| 	for _, group := range groups { | ||||
| func userInGroup(service *admin.Service, group string, email string) bool { | ||||
| 	// Use the HasMember API to checking for the user's presence in each group or nested subgroups
 | ||||
| 	req := service.Members.HasMember(group, email) | ||||
| 	r, err := req.Do() | ||||
| 		if err != nil { | ||||
| 	if err == nil { | ||||
| 		return r.IsMember | ||||
| 	} | ||||
| 
 | ||||
| 	gerr, ok := err.(*googleapi.Error) | ||||
| 	switch { | ||||
| 	case ok && gerr.Code == 404: | ||||
|  | @ -220,10 +250,9 @@ func userInGroup(service *admin.Service, groups []string, email string) bool { | |||
| 		// from the HasMember API. In that case, attempt to query the member object directly from the group.
 | ||||
| 		req := service.Members.Get(group, email) | ||||
| 		r, err := req.Do() | ||||
| 
 | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) | ||||
| 					continue | ||||
| 			return false | ||||
| 		} | ||||
| 
 | ||||
| 		// If the non-domain user is found within the group, still verify that they are "ACTIVE".
 | ||||
|  | @ -234,21 +263,9 @@ func userInGroup(service *admin.Service, groups []string, email string) bool { | |||
| 	default: | ||||
| 		logger.Errorf("error checking group membership: %v", err) | ||||
| 	} | ||||
| 			continue | ||||
| 		} | ||||
| 		if r.IsMember { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // ValidateGroup validates that the provided email exists in the configured Google
 | ||||
| // group(s).
 | ||||
| func (p *GoogleProvider) ValidateGroup(email string) bool { | ||||
| 	return p.GroupValidator(email) | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
| // RefreshToken to fetch a new ID token if required
 | ||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||
|  | @ -261,8 +278,11 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions | |||
| 		return false, err | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO (@NickMeves) - Align Group authorization needs with other providers'
 | ||||
| 	// behavior in the `RefreshSession` case.
 | ||||
| 	//
 | ||||
| 	// re-check that the user is in the proper google group(s)
 | ||||
| 	if !p.ValidateGroup(s.Email) { | ||||
| 	if !p.groupValidator(s) { | ||||
| 		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ import ( | |||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/gomega" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	admin "google.golang.org/api/admin/directory/v1" | ||||
|  | @ -109,21 +110,50 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { | |||
| 	assert.Equal(t, "refresh12345", session.RefreshToken) | ||||
| } | ||||
| 
 | ||||
| func TestGoogleProviderValidateGroup(t *testing.T) { | ||||
| 	p := newGoogleProvider() | ||||
| 	p.GroupValidator = func(email string) bool { | ||||
| 		return email == "michael.bland@gsa.gov" | ||||
| 	} | ||||
| 	assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) | ||||
| 	p.GroupValidator = func(email string) bool { | ||||
| 		return email != "michael.bland@gsa.gov" | ||||
| 	} | ||||
| 	assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) | ||||
| } | ||||
| func TestGoogleProviderGroupValidator(t *testing.T) { | ||||
| 	const sessionEmail = "michael.bland@gsa.gov" | ||||
| 
 | ||||
| func TestGoogleProviderWithoutValidateGroup(t *testing.T) { | ||||
| 	testCases := map[string]struct { | ||||
| 		session       *sessions.SessionState | ||||
| 		validatorFunc func(*sessions.SessionState) bool | ||||
| 		expectedAuthZ bool | ||||
| 	}{ | ||||
| 		"Email is authorized with groupValidator": { | ||||
| 			session: &sessions.SessionState{ | ||||
| 				Email: sessionEmail, | ||||
| 			}, | ||||
| 			validatorFunc: func(s *sessions.SessionState) bool { | ||||
| 				return s.Email == sessionEmail | ||||
| 			}, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 		"Email is denied with groupValidator": { | ||||
| 			session: &sessions.SessionState{ | ||||
| 				Email: sessionEmail, | ||||
| 			}, | ||||
| 			validatorFunc: func(s *sessions.SessionState) bool { | ||||
| 				return s.Email != sessionEmail | ||||
| 			}, | ||||
| 			expectedAuthZ: false, | ||||
| 		}, | ||||
| 		"Default does no authorization checks": { | ||||
| 			session: &sessions.SessionState{ | ||||
| 				Email: sessionEmail, | ||||
| 			}, | ||||
| 			validatorFunc: nil, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	for name, tc := range testCases { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			g := NewWithT(t) | ||||
| 			p := newGoogleProvider() | ||||
| 	assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) | ||||
| 			if tc.validatorFunc != nil { | ||||
| 				p.groupValidator = tc.validatorFunc | ||||
| 			} | ||||
| 			g.Expect(p.groupValidator(tc.session)).To(Equal(tc.expectedAuthZ)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
|  | @ -196,7 +226,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { | |||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestGoogleProviderUserInGroup(t *testing.T) { | ||||
| func TestGoogleProvider_userInGroup(t *testing.T) { | ||||
| 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if r.URL.Path == "/groups/group@example.com/hasMember/member-in-domain@example.com" { | ||||
| 			fmt.Fprintln(w, `{"isMember": true}`) | ||||
|  | @ -233,18 +263,19 @@ func TestGoogleProviderUserInGroup(t *testing.T) { | |||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	service, err := admin.NewService(ctx, option.WithHTTPClient(client)) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	service.BasePath = ts.URL | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	result := userInGroup(service, []string{"group@example.com"}, "member-in-domain@example.com") | ||||
| 	result := userInGroup(service, "group@example.com", "member-in-domain@example.com") | ||||
| 	assert.True(t, result) | ||||
| 
 | ||||
| 	result = userInGroup(service, []string{"group@example.com"}, "member-out-of-domain@otherexample.com") | ||||
| 	result = userInGroup(service, "group@example.com", "member-out-of-domain@otherexample.com") | ||||
| 	assert.True(t, result) | ||||
| 
 | ||||
| 	result = userInGroup(service, []string{"group@example.com"}, "non-member-in-domain@example.com") | ||||
| 	result = userInGroup(service, "group@example.com", "non-member-in-domain@example.com") | ||||
| 	assert.False(t, result) | ||||
| 
 | ||||
| 	result = userInGroup(service, []string{"group@example.com"}, "non-member-out-of-domain@otherexample.com") | ||||
| 	result = userInGroup(service, "group@example.com", "non-member-out-of-domain@otherexample.com") | ||||
| 	assert.False(t, result) | ||||
| } | ||||
|  |  | |||
|  | @ -4,7 +4,6 @@ import ( | |||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"crypto/rsa" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"net/url" | ||||
|  | @ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint | |||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
| 		return nil, ErrMissingCode | ||||
| 	} | ||||
| 
 | ||||
| 	claims := &jwt.StandardClaims{ | ||||
|  | @ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | |||
| 	token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) | ||||
| 	ss, err := token.SignedString(p.JWTKey) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	params := url.Values{} | ||||
|  | @ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | |||
| 	// check nonce here
 | ||||
| 	err = checkNonce(jsonResponse.IDToken, p) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Get the email address
 | ||||
| 	var email string | ||||
| 	email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String()) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | ||||
| 
 | ||||
| 	// Store the data that we found in the session state
 | ||||
| 	s = &sessions.SessionState{ | ||||
| 	return &sessions.SessionState{ | ||||
| 		AccessToken: jsonResponse.AccessToken, | ||||
| 		IDToken:     jsonResponse.IDToken, | ||||
| 		CreatedAt:   &created, | ||||
| 		ExpiresOn:   &expires, | ||||
| 		Email:       email, | ||||
| 	} | ||||
| 	return | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // GetLoginURL overrides GetLoginURL to add login.gov parameters
 | ||||
|  |  | |||
|  | @ -26,6 +26,10 @@ type ProviderData struct { | |||
| 	ClientSecretFile string | ||||
| 	Scope            string | ||||
| 	Prompt           string | ||||
| 
 | ||||
| 	// Universal Group authorization data structure
 | ||||
| 	// any provider can set to consume
 | ||||
| 	AllowedGroups map[string]struct{} | ||||
| } | ||||
| 
 | ||||
| // Data returns the ProviderData
 | ||||
|  | @ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) { | |||
| 	return string(fileClientSecret), nil | ||||
| } | ||||
| 
 | ||||
| // SetAllowedGroups organizes a group list into the AllowedGroups map
 | ||||
| // to be consumed by Authorize implementations
 | ||||
| func (p *ProviderData) SetAllowedGroups(groups []string) { | ||||
| 	p.AllowedGroups = make(map[string]struct{}, len(groups)) | ||||
| 	for _, group := range groups { | ||||
| 		p.AllowedGroups[group] = struct{}{} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type providerDefaults struct { | ||||
| 	name        string | ||||
| 	loginURL    *url.URL | ||||
|  |  | |||
|  | @ -19,18 +19,21 @@ var ( | |||
| 	// implementation method that doesn't have sensible defaults
 | ||||
| 	ErrNotImplemented = errors.New("not implemented") | ||||
| 
 | ||||
| 	// ErrMissingCode is returned when a Redeem method is called with an empty
 | ||||
| 	// code
 | ||||
| 	ErrMissingCode = errors.New("missing code") | ||||
| 
 | ||||
| 	_ Provider = (*ProviderData)(nil) | ||||
| ) | ||||
| 
 | ||||
| // Redeem provides a default implementation of the OAuth2 token redemption process
 | ||||
| func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
| 		return nil, ErrMissingCode | ||||
| 	} | ||||
| 	clientSecret, err := p.GetClientSecret() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	params := url.Values{} | ||||
|  | @ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 	} | ||||
| 	err = result.UnmarshalInto(&jsonResponse) | ||||
| 	if err == nil { | ||||
| 		s = &sessions.SessionState{ | ||||
| 		return &sessions.SessionState{ | ||||
| 			AccessToken: jsonResponse.AccessToken, | ||||
| 		} | ||||
| 		return | ||||
| 		}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	var v url.Values | ||||
| 	v, err = url.ParseQuery(string(result.Body())) | ||||
| 	values, err := url.ParseQuery(string(result.Body())) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if a := v.Get("access_token"); a != "" { | ||||
| 	if token := values.Get("access_token"); token != "" { | ||||
| 		created := time.Now() | ||||
| 		s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} | ||||
| 	} else { | ||||
| 		err = fmt.Errorf("no access token found %s", result.Body()) | ||||
| 		return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil | ||||
| 	} | ||||
| 	return | ||||
| 
 | ||||
| 	return nil, fmt.Errorf("no access token found %s", result.Body()) | ||||
| } | ||||
| 
 | ||||
| // GetLoginURL with typical oauth parameters
 | ||||
|  | @ -92,18 +92,28 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta | |||
| 	return "", ErrNotImplemented | ||||
| } | ||||
| 
 | ||||
| // ValidateGroup validates that the provided email exists in the configured provider
 | ||||
| // email group(s).
 | ||||
| func (p *ProviderData) ValidateGroup(_ string) bool { | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | ||||
| // such as User, Email, Groups with provider specific API calls.
 | ||||
| func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Authorize performs global authorization on an authenticated session.
 | ||||
| // This is not used for fine-grained per route authorization rules.
 | ||||
| func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 	if len(p.AllowedGroups) == 0 { | ||||
| 		return true, nil | ||||
| 	} | ||||
| 
 | ||||
| 	for _, group := range s.Groups { | ||||
| 		if _, ok := p.AllowedGroups[group]; ok { | ||||
| 			return true, nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false, nil | ||||
| } | ||||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||
| 	return validateToken(ctx, p, s.AccessToken, nil) | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/gomega" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) { | |||
| 	s := &sessions.SessionState{} | ||||
| 	assert.NoError(t, p.EnrichSessionState(context.Background(), s)) | ||||
| } | ||||
| 
 | ||||
| func TestProviderDataAuthorize(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name          string | ||||
| 		allowedGroups []string | ||||
| 		groups        []string | ||||
| 		expectedAuthZ bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:          "NoAllowedGroups", | ||||
| 			allowedGroups: []string{}, | ||||
| 			groups:        []string{}, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "NoAllowedGroupsUserHasGroups", | ||||
| 			allowedGroups: []string{}, | ||||
| 			groups:        []string{"foo", "bar"}, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "UserInAllowedGroup", | ||||
| 			allowedGroups: []string{"foo"}, | ||||
| 			groups:        []string{"foo", "bar"}, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "UserNotInAllowedGroup", | ||||
| 			allowedGroups: []string{"bar"}, | ||||
| 			groups:        []string{"baz", "foo"}, | ||||
| 			expectedAuthZ: false, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			g := NewWithT(t) | ||||
| 
 | ||||
| 			session := &sessions.SessionState{ | ||||
| 				Groups: tc.groups, | ||||
| 			} | ||||
| 			p := &ProviderData{} | ||||
| 			p.SetAllowedGroups(tc.allowedGroups) | ||||
| 
 | ||||
| 			authorized, err := p.Authorize(context.Background(), session) | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 			g.Expect(authorized).To(Equal(tc.expectedAuthZ)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -13,8 +13,8 @@ type Provider interface { | |||
| 	// DEPRECATED: Migrate to EnrichSessionState
 | ||||
| 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) | ||||
| 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) | ||||
| 	ValidateGroup(string) bool | ||||
| 	EnrichSessionState(ctx context.Context, s *sessions.SessionState) error | ||||
| 	Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||
| 	ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool | ||||
| 	GetLoginURL(redirectURI, finalRedirect string) string | ||||
| 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue