diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f09d929..d5111659 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,9 +13,13 @@ ## Breaking Changes +- [#1239](https://github.com/oauth2-proxy/oauth2-proxy/pull/1239) GitLab groups sent in the `X-Forwarded-Groups` header + to the upstream server will no longer be prefixed with `group:` + ## Changes since v7.1.3 - [#1337](https://github.com/oauth2-proxy/oauth2-proxy/pull/1337) Changing user field type to text when using htpasswd (@pburgisser) +- [#1239](https://github.com/oauth2-proxy/oauth2-proxy/pull/1239) Base GitLab provider implementation on OIDCProvider (@NickMeves) - [#1276](https://github.com/oauth2-proxy/oauth2-proxy/pull/1276) Update crypto and switched to new github.com/golang-jwt/jwt (@JVecsei) - [#1264](https://github.com/oauth2-proxy/oauth2-proxy/pull/1264) Update go-oidc to v3 (@NickMeves) - [#1233](https://github.com/oauth2-proxy/oauth2-proxy/pull/1233) Extend email-domain validation with sub-domain capability (@morarucostel) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 37c2aa24..ce1d219e 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -276,13 +276,11 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { msgs = append(msgs, "oidc provider requires an oidc issuer URL") } case *providers.GitLabProvider: - p.Groups = o.Providers[0].GitLabConfig.Group - err := p.AddProjects(o.Providers[0].GitLabConfig.Projects) + p.SetAllowedGroups(o.Providers[0].GitLabConfig.Group) + err := p.SetAllowedProjects(o.Providers[0].GitLabConfig.Projects) if err != nil { msgs = append(msgs, "failed to setup gitlab project access level") } - p.SetAllowedGroups(p.PrefixAllowedGroups()) - p.SetProjectScope() if p.Verifier == nil { // Initialize with default verifier for gitlab.com diff --git a/providers/gitlab.go b/providers/gitlab.go index a2b11df7..b73e19d3 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -6,194 +6,209 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" - "golang.org/x/oauth2" +) + +const ( + gitlabProviderName = "GitLab" + gitlabDefaultScope = "openid email" + gitlabProjectPrefix = "project:" ) // GitLabProvider represents a GitLab based Identity Provider type GitLabProvider struct { - *ProviderData + *OIDCProvider - Groups []string - Projects []*GitlabProject + allowedProjects []*gitlabProject + // Expose this for unit testing + oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error) } -// GitlabProject represents a Gitlab project constraint entity -type GitlabProject struct { +var _ Provider = (*GitLabProvider)(nil) + +// NewGitLabProvider initiates a new GitLabProvider +func NewGitLabProvider(p *ProviderData) *GitLabProvider { + p.ProviderName = gitlabProviderName + if p.Scope == "" { + p.Scope = gitlabDefaultScope + } + + oidcProvider := &OIDCProvider{ + ProviderData: p, + SkipNonce: false, + } + + return &GitLabProvider{ + OIDCProvider: oidcProvider, + oidcRefreshFunc: oidcProvider.RefreshSession, + } +} + +// SetAllowedProjects adds Gitlab projects to the AllowedGroups list +// and tracks them to do a project API lookup during `EnrichSession`. +func (p *GitLabProvider) SetAllowedProjects(projects []string) error { + for _, project := range projects { + gp, err := newGitlabProject(project) + if err != nil { + return err + } + p.allowedProjects = append(p.allowedProjects, gp) + p.AllowedGroups[formatProject(gp)] = struct{}{} + } + if len(p.allowedProjects) > 0 { + p.setProjectScope() + } + return nil +} + +// gitlabProject represents a Gitlab project constraint entity +type gitlabProject struct { Name string AccessLevel int } -// newGitlabProject Creates a new GitlabProject struct from project string formatted as namespace/project=accesslevel +// newGitlabProject Creates a new GitlabProject struct from project string +// formatted as `namespace/project=accesslevel` // if no accesslevel provided, use the default one -func newGitlabproject(project string) (*GitlabProject, error) { - // default access level is 20 - defaultAccessLevel := 20 +func newGitlabProject(project string) (*gitlabProject, error) { + const defaultAccessLevel = 20 // see https://docs.gitlab.com/ee/api/members.html#valid-access-levels validAccessLevel := [4]int{10, 20, 30, 40} parts := strings.SplitN(project, "=", 2) - if len(parts) == 2 { lvl, err := strconv.Atoi(parts[1]) if err != nil { return nil, err } - for _, valid := range validAccessLevel { if lvl == valid { - return &GitlabProject{ - Name: parts[0], - AccessLevel: lvl}, - err + return &gitlabProject{ + Name: parts[0], + AccessLevel: lvl, + }, nil } } - return nil, fmt.Errorf("invalid gitlab project access level specified (%s)", parts[0]) - } - return &GitlabProject{ - Name: project, - AccessLevel: defaultAccessLevel}, - nil - + return &gitlabProject{ + Name: project, + AccessLevel: defaultAccessLevel, + }, nil } -var _ Provider = (*GitLabProvider)(nil) - -const ( - gitlabProviderName = "GitLab" - gitlabDefaultScope = "openid email" -) - -// NewGitLabProvider initiates a new GitLabProvider -func NewGitLabProvider(p *ProviderData) *GitLabProvider { - p.ProviderName = gitlabProviderName - - if p.Scope == "" { - p.Scope = gitlabDefaultScope - } - - return &GitLabProvider{ProviderData: p} -} - -// Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { - clientSecret, err := p.GetClientSecret() - if err != nil { - return - } - - c := oauth2.Config{ - ClientID: p.ClientID, - ClientSecret: clientSecret, - Endpoint: oauth2.Endpoint{ - TokenURL: p.RedeemURL.String(), - }, - RedirectURL: redirectURL, - } - token, err := c.Exchange(ctx, code) - if err != nil { - return nil, fmt.Errorf("token exchange: %v", err) - } - s, err = p.createSession(ctx, token) - if err != nil { - return nil, fmt.Errorf("unable to update session: %v", err) - } - return -} - -// SetProjectScope ensure read_api is added to scope when filtering on projects -func (p *GitLabProvider) SetProjectScope() { - if len(p.Projects) > 0 { - for _, val := range strings.Split(p.Scope, " ") { - if val == "read_api" { - return - } - +// setProjectScope ensures read_api is added to scope when filtering on projects +func (p *GitLabProvider) setProjectScope() { + for _, val := range strings.Split(p.Scope, " ") { + if val == "read_api" { + return } - p.Scope += " read_api" } + p.Scope += " read_api" } -// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens -func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.RefreshToken == "" { - return false, nil - } - - origExpiration := s.ExpiresOn - - err := p.redeemRefreshToken(ctx, s) +// EnrichSession enriches the session with the response from the userinfo API +// endpoint & projects API endpoint for allowed projects. +func (p *GitLabProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { + // Retrieve user info + userinfo, err := p.getUserinfo(ctx, s) if err != nil { - return false, fmt.Errorf("unable to redeem refresh token: %v", err) + return fmt.Errorf("failed to retrieve user info: %v", err) } - logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) - return true, nil -} - -func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { - clientSecret, err := p.GetClientSecret() - if err != nil { - return err + // Check if email is verified + if !p.AllowUnverifiedEmail && !userinfo.EmailVerified { + return fmt.Errorf("user email is not verified") } - c := oauth2.Config{ - ClientID: p.ClientID, - ClientSecret: clientSecret, - Endpoint: oauth2.Endpoint{ - TokenURL: p.RedeemURL.String(), - }, + if userinfo.Nickname != "" { + s.User = userinfo.Nickname } - t := &oauth2.Token{ - RefreshToken: s.RefreshToken, - Expiry: time.Now().Add(-time.Hour), + if userinfo.Email != "" { + s.Email = userinfo.Email } - token, err := c.TokenSource(ctx, t).Token() - if err != nil { - return fmt.Errorf("failed to get token: %v", err) + if len(userinfo.Groups) > 0 { + s.Groups = userinfo.Groups } - newSession, err := p.createSession(ctx, token) - if err != nil { - return fmt.Errorf("unable to update session: %v", err) - } - *s = *newSession + + // Add projects as `project:blah` to s.Groups + p.addProjectsToSession(ctx, s) return nil } -type gitlabUserInfo struct { - Username string `json:"nickname"` +type gitlabUserinfo struct { + Nickname string `json:"nickname"` Email string `json:"email"` EmailVerified bool `json:"email_verified"` Groups []string `json:"groups"` } -func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionState) (*gitlabUserInfo, error) { +func (p *GitLabProvider) getUserinfo(ctx context.Context, s *sessions.SessionState) (*gitlabUserinfo, error) { // Retrieve user info JSON // https://docs.gitlab.com/ee/integration/openid_connect_provider.html#shared-information // Build user info url from login url of GitLab instance - userInfoURL := *p.LoginURL - userInfoURL.Path = "/oauth/userinfo" + userinfoURL := *p.LoginURL + userinfoURL.Path = "/oauth/userinfo" - var userInfo gitlabUserInfo - err := requests.New(userInfoURL.String()). + var userinfo gitlabUserinfo + err := requests.New(userinfoURL.String()). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). Do(). - UnmarshalInto(&userInfo) + UnmarshalInto(&userinfo) if err != nil { return nil, fmt.Errorf("error getting user info: %v", err) } - return &userInfo, nil + return &userinfo, nil +} + +// addProjectsToSession adds projects matching user access requirements into +// the session state groups list. +// This method prefixes projects names with `project:` to specify group kind. +func (p *GitLabProvider) addProjectsToSession(ctx context.Context, s *sessions.SessionState) { + // Iterate over projects, check if oauth2-proxy can get project information on behalf of the user + for _, project := range p.allowedProjects { + projectInfo, err := p.getProjectInfo(ctx, s, project.Name) + if err != nil { + logger.Errorf("Warning: project info request failed: %v", err) + continue + } + + if projectInfo.Archived { + logger.Errorf("Warning: project %s is archived", project.Name) + continue + } + + perms := projectInfo.Permissions.ProjectAccess + if perms == nil { + // use group project access as fallback + perms = projectInfo.Permissions.GroupAccess + // group project access is not set for this user then we give up + if perms == nil { + logger.Errorf("Warning: user %q has no project level access to %s", + s.Email, project.Name) + continue + } + } + + if perms.AccessLevel < project.AccessLevel { + logger.Errorf( + "Warning: user %q does not have the minimum required access level for project %q", + s.Email, + project.Name, + ) + continue + } + + s.Groups = append(s.Groups, formatProject(project)) + } } type gitlabPermissionAccess struct { @@ -226,7 +241,6 @@ func (p *GitLabProvider) getProjectInfo(ctx context.Context, s *sessions.Session SetHeader("Authorization", "Bearer "+s.AccessToken). Do(). UnmarshalInto(&projectInfo) - if err != nil { return nil, fmt.Errorf("failed to get project info: %v", err) } @@ -234,116 +248,45 @@ func (p *GitLabProvider) getProjectInfo(ctx context.Context, s *sessions.Session return &projectInfo, nil } -// AddProjects adds Gitlab projects from options to GitlabProvider struct -func (p *GitLabProvider) AddProjects(projects []string) error { - for _, project := range projects { - gp, err := newGitlabproject(project) - if err != nil { - return err - } - - p.Projects = append(p.Projects, gp) - } - - return nil +func formatProject(project *gitlabProject) string { + return gitlabProjectPrefix + project.Name } -func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { - idToken, err := p.verifyIDToken(ctx, token) - if err != nil { - switch err { - case ErrMissingIDToken: - return nil, fmt.Errorf("token response did not contain an id_token") - default: - return nil, fmt.Errorf("could not verify id_token: %v", err) +// RefreshSession refreshes the session with the OIDCProvider implementation +// but preserves the custom GitLab projects added in the `EnrichSession` stage. +func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + nickname := s.User + projects := getSessionProjects(s) + // This will overwrite s.Groups with the new IDToken's `groups` claims + // and s.User with the `sub` claim. + refreshed, err := p.oidcRefreshFunc(ctx, s) + if refreshed && err == nil { + s.User = nickname + s.Groups = append(s.Groups, projects...) + s.Groups = deduplicateGroups(s.Groups) + } + return refreshed, err +} + +func getSessionProjects(s *sessions.SessionState) []string { + var projects []string + for _, group := range s.Groups { + if strings.HasPrefix(group, gitlabProjectPrefix) { + projects = append(projects, group) } } - - ss := &sessions.SessionState{ - AccessToken: token.AccessToken, - IDToken: getIDToken(token), - RefreshToken: token.RefreshToken, - } - - ss.CreatedAtNow() - ss.SetExpiresOn(idToken.Expiry) - - return ss, nil + return projects } -// ValidateSession checks that the session's IDToken is still valid -func (p *GitLabProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { - _, err := p.Verifier.Verify(ctx, s.IDToken) - return err == nil -} - -// EnrichSession adds values and data from the Gitlab endpoint to current session -func (p *GitLabProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { - // Retrieve user info - userInfo, err := p.getUserInfo(ctx, s) - if err != nil { - return fmt.Errorf("failed to retrieve user info: %v", err) - } - - // Check if email is verified - if !p.AllowUnverifiedEmail && !userInfo.EmailVerified { - return fmt.Errorf("user email is not verified") - } - - s.User = userInfo.Username - s.Email = userInfo.Email - for _, group := range userInfo.Groups { - s.Groups = append(s.Groups, fmt.Sprintf("group:%s", group)) - } - - p.addProjectsToSession(ctx, s) - - return nil -} - -// addProjectsToSession adds projects matching user access requirements into the session state groups list -// This method prefix projects names with `project` to specify group kind -func (p *GitLabProvider) addProjectsToSession(ctx context.Context, s *sessions.SessionState) { - // Iterate over projects, check if oauth2-proxy can get project information on behalf of the user - for _, project := range p.Projects { - projectInfo, err := p.getProjectInfo(ctx, s, project.Name) - - if err != nil { - logger.Errorf("Warning: project info request failed: %v", err) - continue - } - - if !projectInfo.Archived { - perms := projectInfo.Permissions.ProjectAccess - if perms == nil { - // use group project access as fallback - perms = projectInfo.Permissions.GroupAccess - // group project access is not set for this user then we give up - if perms == nil { - logger.Errorf("Warning: user %q has no project level access to %s", s.Email, project.Name) - continue - } - } - - if perms != nil && perms.AccessLevel >= project.AccessLevel { - s.Groups = append(s.Groups, fmt.Sprintf("project:%s", project.Name)) - } else { - logger.Errorf("Warning: user %q does not have the minimum required access level for project %q", s.Email, project.Name) - } - continue - } - - logger.Errorf("Warning: project %s is archived", project.Name) - } -} - -// PrefixAllowedGroups returns a list of allowed groups, prefixed by their `kind` value -func (p *GitLabProvider) PrefixAllowedGroups() (groups []string) { - for _, val := range p.Groups { - groups = append(groups, fmt.Sprintf("group:%s", val)) - } - for _, val := range p.Projects { - groups = append(groups, fmt.Sprintf("project:%s", val.Name)) - } - return groups +func deduplicateGroups(groups []string) []string { + groupSet := make(map[string]struct{}) + for _, group := range groups { + groupSet[group] = struct{}{} + } + + uniqueGroups := make([]string, 0, len(groupSet)) + for group := range groupSet { + uniqueGroups = append(uniqueGroups, group) + } + return uniqueGroups } diff --git a/providers/gitlab_test.go b/providers/gitlab_test.go index 9c3fb00e..d6635c27 100644 --- a/providers/gitlab_test.go +++ b/providers/gitlab_test.go @@ -188,7 +188,7 @@ var _ = Describe("Gitlab Provider Tests", func() { err := p.EnrichSession(context.Background(), session) if in.expectedError != nil { - Expect(err).To(MatchError(err)) + Expect(err).To(MatchError(in.expectedError)) } else { Expect(err).To(BeNil()) Expect(session.Email).To(Equal(in.expectedValue)) @@ -208,98 +208,165 @@ var _ = Describe("Gitlab Provider Tests", func() { Context("when filtering on gitlab entities (groups and projects)", func() { type entitiesTableInput struct { - expectedValue []string - projects []string - groups []string + allowedProjects []string + allowedGroups []string + scope string + expectedAuthz bool + expectedError error + expectedGroups []string + expectedScope string } DescribeTable("should return expected results", func(in entitiesTableInput) { p.AllowUnverifiedEmail = true + if in.scope != "" { + p.Scope = in.scope + } session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - err := p.AddProjects(in.projects) - Expect(err).To(BeNil()) - p.SetProjectScope() + p.SetAllowedGroups(in.allowedGroups) - if len(in.groups) > 0 { - p.Groups = in.groups + err := p.SetAllowedProjects(in.allowedProjects) + if in.expectedError == nil { + Expect(err).To(BeNil()) + } else { + Expect(err).To(MatchError(in.expectedError)) + return } + Expect(p.Scope).To(Equal(in.expectedScope)) err = p.EnrichSession(context.Background(), session) - Expect(err).To(BeNil()) - Expect(session.Groups).To(Equal(in.expectedValue)) + Expect(session.Groups).To(Equal(in.expectedGroups)) + + authorized, err := p.Authorize(context.Background(), session) + Expect(err).To(BeNil()) + Expect(authorized).To(Equal(in.expectedAuthz)) }, Entry("project membership valid on group project", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar", "project:my_group/my_project"}, - projects: []string{"my_group/my_project"}, + allowedProjects: []string{"my_group/my_project"}, + expectedAuthz: true, + expectedGroups: []string{"foo", "bar", "project:my_group/my_project"}, + expectedScope: "openid email read_api", }), Entry("project membership invalid on group project, insufficient access level level", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar"}, - projects: []string{"my_group/my_project=40"}, + allowedProjects: []string{"my_group/my_project=40"}, + expectedAuthz: false, + expectedGroups: []string{"foo", "bar"}, + expectedScope: "openid email read_api", }), Entry("project membership invalid on group project, no access at all", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar"}, - projects: []string{"no_access_group/no_access_project=30"}, + allowedProjects: []string{"no_access_group/no_access_project=30"}, + expectedAuthz: false, + expectedGroups: []string{"foo", "bar"}, + expectedScope: "openid email read_api", }), Entry("project membership valid on personnal project", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar", "project:my_profile/my_personal_project"}, - projects: []string{"my_profile/my_personal_project"}, + allowedProjects: []string{"my_profile/my_personal_project"}, + scope: "openid email read_api profile", + expectedAuthz: true, + expectedGroups: []string{"foo", "bar", "project:my_profile/my_personal_project"}, + expectedScope: "openid email read_api profile", }), Entry("project membership invalid on personnal project, insufficient access level", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar"}, - projects: []string{"my_profile/my_personal_project=40"}, + allowedProjects: []string{"my_profile/my_personal_project=40"}, + expectedAuthz: false, + expectedGroups: []string{"foo", "bar"}, + expectedScope: "openid email read_api", }), Entry("project membership invalid", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar"}, - projects: []string{"my_group/my_bad_project"}, + allowedProjects: []string{"my_group/my_bad_project"}, + expectedAuthz: false, + expectedGroups: []string{"foo", "bar"}, + expectedScope: "openid email read_api", }), Entry("group membership valid", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar"}, - groups: []string{"foo"}, + allowedGroups: []string{"foo"}, + expectedGroups: []string{"foo", "bar"}, + expectedAuthz: true, + expectedScope: "openid email", }), Entry("groups and projects", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar", "project:my_group/my_project", "project:my_profile/my_personal_project"}, - groups: []string{"foo", "baz"}, - projects: []string{"my_group/my_project", "my_profile/my_personal_project"}, + allowedGroups: []string{"foo", "baz"}, + allowedProjects: []string{"my_group/my_project", "my_profile/my_personal_project"}, + expectedAuthz: true, + expectedGroups: []string{"foo", "bar", "project:my_group/my_project", "project:my_profile/my_personal_project"}, + expectedScope: "openid email read_api", }), Entry("archived projects", entitiesTableInput{ - expectedValue: []string{"group:foo", "group:bar"}, - groups: []string{}, - projects: []string{"my_group/my_archived_project"}, + allowedProjects: []string{"my_group/my_archived_project"}, + expectedAuthz: false, + expectedGroups: []string{"foo", "bar"}, + expectedScope: "openid email read_api", + }), + Entry("invalid project format", entitiesTableInput{ + allowedProjects: []string{"my_group/my_invalid_project=123"}, + expectedError: errors.New("invalid gitlab project access level specified (my_group/my_invalid_project)"), + expectedScope: "openid email read_api", }), ) - }) - Context("when generating group list from multiple kind", func() { - type entitiesTableInput struct { - projects []string - groups []string - } + Context("when refreshing", func() { + It("keeps the existing nickname after refreshing", func() { + session := &sessions.SessionState{ + User: "nickname", + } + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + s.User = "subject" + return true, nil + } + refreshed, err := p.RefreshSession(context.Background(), session) + Expect(refreshed).To(BeTrue()) + Expect(err).ToNot(HaveOccurred()) + Expect(session.User).To(Equal("nickname")) + }) + It("keeps existing projects after refreshing groups", func() { + session := &sessions.SessionState{} + session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} - DescribeTable("should prefix entities with group kind", func(in entitiesTableInput) { - p.Groups = in.groups - err := p.AddProjects(in.projects) - Expect(err).To(BeNil()) + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + s.Groups = []string{"baz"} + return true, nil + } - all := p.PrefixAllowedGroups() + refreshed, err := p.RefreshSession(context.Background(), session) + Expect(refreshed).To(BeTrue()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(session.Groups)).To(Equal(3)) + Expect(session.Groups). + To(ContainElements([]string{"baz", "project:thing", "project:sample"})) + }) + It("leaves existing groups when not refreshed", func() { + session := &sessions.SessionState{} + session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} - Expect(len(all)).To(Equal(len(in.projects) + len(in.groups))) - }, - Entry("simple test case", entitiesTableInput{ - projects: []string{"my_group/my_project", "my_group/my_other_project"}, - groups: []string{"mygroup", "myothergroup"}, - }), - Entry("projects only", entitiesTableInput{ - projects: []string{"my_group/my_project", "my_group/my_other_project"}, - groups: []string{}, - }), - Entry("groups only", entitiesTableInput{ - projects: []string{}, - groups: []string{"mygroup", "myothergroup"}, - }), - ) + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + return false, nil + } + + refreshed, err := p.RefreshSession(context.Background(), session) + Expect(refreshed).To(BeFalse()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(session.Groups)).To(Equal(4)) + Expect(session.Groups). + To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) + }) + It("leaves existing groups when OIDC refresh errors", func() { + session := &sessions.SessionState{} + session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} + + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + return false, errors.New("failure") + } + + refreshed, err := p.RefreshSession(context.Background(), session) + Expect(refreshed).To(BeFalse()) + Expect(err).To(HaveOccurred()) + Expect(len(session.Groups)).To(Equal(4)) + Expect(session.Groups). + To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) + }) }) }) diff --git a/providers/provider_data.go b/providers/provider_data.go index ccd6e47f..de2aae3e 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -41,6 +41,7 @@ type ProviderData struct { // Common OIDC options for any OIDC-based providers to consume AllowUnverifiedEmail bool + UserClaim string EmailClaim string GroupsClaim string Verifier *oidc.IDTokenVerifier @@ -156,6 +157,17 @@ func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions. ss.Email = claims.Email ss.Groups = claims.Groups + // Allow specialized providers that embed OIDCProvider to control the User + // claim. Not exposed as a configuration flag to generic OIDC provider + // users (yet). + if p.UserClaim != "" { + user, ok := claims.raw[p.UserClaim].(string) + if !ok { + return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim) + } + ss.User = user + } + // TODO (@NickMeves) Deprecate for dynamic claim to session mapping if pref, ok := claims.raw["preferred_username"].(string); ok { ss.PreferredUsername = pref diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index 09a19ded..27484165 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -211,6 +211,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { testCases := map[string]struct { IDToken idTokenClaims AllowUnverified bool + UserClaim string EmailClaim string GroupsClaim string ExpectedError error @@ -259,6 +260,27 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { PreferredUsername: "Complex Claim", }, }, + "User Claim Switched": { + IDToken: defaultIDToken, + AllowUnverified: true, + UserClaim: "phone_number", + EmailClaim: "email", + GroupsClaim: "groups", + ExpectedSession: &sessions.SessionState{ + User: "+4798765432", + Email: "janed@me.com", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Jane Dobbs", + }, + }, + "User Claim Invalid": { + IDToken: defaultIDToken, + AllowUnverified: true, + UserClaim: "groups", + EmailClaim: "email", + GroupsClaim: "groups", + ExpectedError: errors.New("unable to extract custom UserClaim (groups)"), + }, "Email Claim Switched": { IDToken: unverifiedIDToken, AllowUnverified: true, @@ -332,6 +354,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { ), } provider.AllowUnverifiedEmail = tc.AllowUnverified + provider.UserClaim = tc.UserClaim provider.EmailClaim = tc.EmailClaim provider.GroupsClaim = tc.GroupsClaim