Preserve projects after `RefreshSession`
RefreshSession will override session.Groups with the new `groups` claims. We need to preserve all `project:` prefixed groups and reattach them post refresh.
This commit is contained in:
		
							parent
							
								
									11c2177f18
								
							
						
					
					
						commit
						95f9de5979
					
				|  | @ -13,9 +13,10 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	gitlabProviderName = "GitLab" | ||||
| 	gitlabDefaultScope = "openid email" | ||||
| 	gitlabUserClaim    = "nickname" | ||||
| 	gitlabProviderName  = "GitLab" | ||||
| 	gitlabDefaultScope  = "openid email" | ||||
| 	gitlabUserClaim     = "nickname" | ||||
| 	gitlabProjectPrefix = "project:" | ||||
| ) | ||||
| 
 | ||||
| // GitLabProvider represents a GitLab based Identity Provider
 | ||||
|  | @ -23,6 +24,8 @@ type GitLabProvider struct { | |||
| 	*OIDCProvider | ||||
| 
 | ||||
| 	allowedProjects []*gitlabProject | ||||
| 	// Expose this for unit testing
 | ||||
| 	oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error) | ||||
| } | ||||
| 
 | ||||
| var _ Provider = (*GitLabProvider)(nil) | ||||
|  | @ -35,11 +38,14 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { | |||
| 		p.Scope = gitlabDefaultScope | ||||
| 	} | ||||
| 
 | ||||
| 	oidcProvider := &OIDCProvider{ | ||||
| 		ProviderData: p, | ||||
| 		SkipNonce:    false, | ||||
| 	} | ||||
| 
 | ||||
| 	return &GitLabProvider{ | ||||
| 		OIDCProvider: &OIDCProvider{ | ||||
| 			ProviderData: p, | ||||
| 			SkipNonce:    false, | ||||
| 		}, | ||||
| 		OIDCProvider:    oidcProvider, | ||||
| 		oidcRefreshFunc: oidcProvider.RefreshSession, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -245,5 +251,41 @@ func (p *GitLabProvider) getProjectInfo(ctx context.Context, s *sessions.Session | |||
| } | ||||
| 
 | ||||
| func formatProject(project *gitlabProject) string { | ||||
| 	return fmt.Sprintf("project:%s", project.Name) | ||||
| 	return gitlabProjectPrefix + project.Name | ||||
| } | ||||
| 
 | ||||
| // RefreshSession refreshes the session with the OIDCProvider implementation
 | ||||
| // but preserves the custom GitLab projects added in the `EnrichSession` stage.
 | ||||
| func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 	projects := getSessionProjects(s) | ||||
| 	// This will overwrite s.Groups with the new IDToken's `groups` claims
 | ||||
| 	refreshed, err := p.oidcRefreshFunc(ctx, s) | ||||
| 	if refreshed && err == nil { | ||||
| 		s.Groups = append(s.Groups, projects...) | ||||
| 		s.Groups = deduplicateGroups(s.Groups) | ||||
| 	} | ||||
| 	return refreshed, err | ||||
| } | ||||
| 
 | ||||
| func getSessionProjects(s *sessions.SessionState) []string { | ||||
| 	var projects []string | ||||
| 	for _, group := range s.Groups { | ||||
| 		if strings.HasPrefix(group, gitlabProjectPrefix) { | ||||
| 			projects = append(projects, group) | ||||
| 		} | ||||
| 	} | ||||
| 	return projects | ||||
| } | ||||
| 
 | ||||
| func deduplicateGroups(groups []string) []string { | ||||
| 	groupSet := make(map[string]struct{}) | ||||
| 	for _, group := range groups { | ||||
| 		groupSet[group] = struct{}{} | ||||
| 	} | ||||
| 
 | ||||
| 	uniqueGroups := make([]string, 0, len(groupSet)) | ||||
| 	for group := range groupSet { | ||||
| 		uniqueGroups = append(uniqueGroups, group) | ||||
| 	} | ||||
| 	return uniqueGroups | ||||
| } | ||||
|  |  | |||
|  | @ -307,4 +307,55 @@ var _ = Describe("Gitlab Provider Tests", func() { | |||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("when refreshing", func() { | ||||
| 		It("keeps existing projects after refreshing groups", func() { | ||||
| 			session := &sessions.SessionState{} | ||||
| 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | ||||
| 
 | ||||
| 			p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 				s.Groups = []string{"baz"} | ||||
| 				return true, nil | ||||
| 			} | ||||
| 
 | ||||
| 			refreshed, err := p.RefreshSession(context.Background(), session) | ||||
| 			Expect(refreshed).To(BeTrue()) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(len(session.Groups)).To(Equal(3)) | ||||
| 			Expect(session.Groups). | ||||
| 				To(ContainElements([]string{"baz", "project:thing", "project:sample"})) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("leaves existing groups when not refreshed", func() { | ||||
| 			session := &sessions.SessionState{} | ||||
| 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | ||||
| 
 | ||||
| 			p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 				return false, nil | ||||
| 			} | ||||
| 
 | ||||
| 			refreshed, err := p.RefreshSession(context.Background(), session) | ||||
| 			Expect(refreshed).To(BeFalse()) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(len(session.Groups)).To(Equal(4)) | ||||
| 			Expect(session.Groups). | ||||
| 				To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("leaves existing groups when OIDC refresh errors", func() { | ||||
| 			session := &sessions.SessionState{} | ||||
| 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | ||||
| 
 | ||||
| 			p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 				return false, errors.New("failure") | ||||
| 			} | ||||
| 
 | ||||
| 			refreshed, err := p.RefreshSession(context.Background(), session) | ||||
| 			Expect(refreshed).To(BeFalse()) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 			Expect(len(session.Groups)).To(Equal(4)) | ||||
| 			Expect(session.Groups). | ||||
| 				To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue