Preserve Nickname around refreshes
This commit is contained in:
		
							parent
							
								
									95f9de5979
								
							
						
					
					
						commit
						e4a8c98e1b
					
				|  | @ -15,7 +15,6 @@ import ( | ||||||
| const ( | const ( | ||||||
| 	gitlabProviderName  = "GitLab" | 	gitlabProviderName  = "GitLab" | ||||||
| 	gitlabDefaultScope  = "openid email" | 	gitlabDefaultScope  = "openid email" | ||||||
| 	gitlabUserClaim     = "nickname" |  | ||||||
| 	gitlabProjectPrefix = "project:" | 	gitlabProjectPrefix = "project:" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -33,7 +32,6 @@ var _ Provider = (*GitLabProvider)(nil) | ||||||
| // NewGitLabProvider initiates a new GitLabProvider
 | // NewGitLabProvider initiates a new GitLabProvider
 | ||||||
| func NewGitLabProvider(p *ProviderData) *GitLabProvider { | func NewGitLabProvider(p *ProviderData) *GitLabProvider { | ||||||
| 	p.ProviderName = gitlabProviderName | 	p.ProviderName = gitlabProviderName | ||||||
| 	p.UserClaim = gitlabUserClaim |  | ||||||
| 	if p.Scope == "" { | 	if p.Scope == "" { | ||||||
| 		p.Scope = gitlabDefaultScope | 		p.Scope = gitlabDefaultScope | ||||||
| 	} | 	} | ||||||
|  | @ -257,10 +255,13 @@ func formatProject(project *gitlabProject) string { | ||||||
| // RefreshSession refreshes the session with the OIDCProvider implementation
 | // RefreshSession refreshes the session with the OIDCProvider implementation
 | ||||||
| // but preserves the custom GitLab projects added in the `EnrichSession` stage.
 | // but preserves the custom GitLab projects added in the `EnrichSession` stage.
 | ||||||
| func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 	nickname := s.User | ||||||
| 	projects := getSessionProjects(s) | 	projects := getSessionProjects(s) | ||||||
| 	// This will overwrite s.Groups with the new IDToken's `groups` claims
 | 	// This will overwrite s.Groups with the new IDToken's `groups` claims
 | ||||||
|  | 	// and s.User with the `sub` claim.
 | ||||||
| 	refreshed, err := p.oidcRefreshFunc(ctx, s) | 	refreshed, err := p.oidcRefreshFunc(ctx, s) | ||||||
| 	if refreshed && err == nil { | 	if refreshed && err == nil { | ||||||
|  | 		s.User = nickname | ||||||
| 		s.Groups = append(s.Groups, projects...) | 		s.Groups = append(s.Groups, projects...) | ||||||
| 		s.Groups = deduplicateGroups(s.Groups) | 		s.Groups = deduplicateGroups(s.Groups) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -309,6 +309,19 @@ var _ = Describe("Gitlab Provider Tests", func() { | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	Context("when refreshing", func() { | 	Context("when refreshing", func() { | ||||||
|  | 		It("keeps the existing nickname after refreshing", func() { | ||||||
|  | 			session := &sessions.SessionState{ | ||||||
|  | 				User: "nickname", | ||||||
|  | 			} | ||||||
|  | 			p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 				s.User = "subject" | ||||||
|  | 				return true, nil | ||||||
|  | 			} | ||||||
|  | 			refreshed, err := p.RefreshSession(context.Background(), session) | ||||||
|  | 			Expect(refreshed).To(BeTrue()) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(session.User).To(Equal("nickname")) | ||||||
|  | 		}) | ||||||
| 		It("keeps existing projects after refreshing groups", func() { | 		It("keeps existing projects after refreshing groups", func() { | ||||||
| 			session := &sessions.SessionState{} | 			session := &sessions.SessionState{} | ||||||
| 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | ||||||
|  | @ -325,7 +338,6 @@ var _ = Describe("Gitlab Provider Tests", func() { | ||||||
| 			Expect(session.Groups). | 			Expect(session.Groups). | ||||||
| 				To(ContainElements([]string{"baz", "project:thing", "project:sample"})) | 				To(ContainElements([]string{"baz", "project:thing", "project:sample"})) | ||||||
| 		}) | 		}) | ||||||
| 
 |  | ||||||
| 		It("leaves existing groups when not refreshed", func() { | 		It("leaves existing groups when not refreshed", func() { | ||||||
| 			session := &sessions.SessionState{} | 			session := &sessions.SessionState{} | ||||||
| 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | ||||||
|  | @ -341,7 +353,6 @@ var _ = Describe("Gitlab Provider Tests", func() { | ||||||
| 			Expect(session.Groups). | 			Expect(session.Groups). | ||||||
| 				To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) | 				To(ContainElements([]string{"foo", "bar", "project:thing", "project:sample"})) | ||||||
| 		}) | 		}) | ||||||
| 
 |  | ||||||
| 		It("leaves existing groups when OIDC refresh errors", func() { | 		It("leaves existing groups when OIDC refresh errors", func() { | ||||||
| 			session := &sessions.SessionState{} | 			session := &sessions.SessionState{} | ||||||
| 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | 			session.Groups = []string{"foo", "bar", "project:thing", "project:sample"} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue