Refactor OIDC to EnrichSession
This commit is contained in:
		
							parent
							
								
									4fda907830
								
							
						
					
					
						commit
						a1877434b2
					
				|  | @ -274,7 +274,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | ||||||
| 		p.SetRepository(o.BitbucketRepository) | 		p.SetRepository(o.BitbucketRepository) | ||||||
| 	case *providers.OIDCProvider: | 	case *providers.OIDCProvider: | ||||||
| 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | ||||||
| 		p.UserIDClaim = o.UserIDClaim | 		p.EmailClaim = o.UserIDClaim | ||||||
| 		p.GroupsClaim = o.OIDCGroupsClaim | 		p.GroupsClaim = o.OIDCGroupsClaim | ||||||
| 		if p.Verifier == nil { | 		if p.Verifier == nil { | ||||||
| 			msgs = append(msgs, "oidc provider requires an oidc issuer URL") | 			msgs = append(msgs, "oidc provider requires an oidc issuer URL") | ||||||
|  |  | ||||||
|  | @ -2,18 +2,17 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	"github.com/coreos/go-oidc" | ||||||
| 	"golang.org/x/oauth2" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const emailClaim = "email" | const emailClaim = "email" | ||||||
|  | @ -23,7 +22,7 @@ type OIDCProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| 
 | 
 | ||||||
| 	AllowUnverifiedEmail bool | 	AllowUnverifiedEmail bool | ||||||
| 	UserIDClaim          string | 	EmailClaim           string | ||||||
| 	GroupsClaim          string | 	GroupsClaim          string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -36,10 +35,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { | ||||||
| var _ Provider = (*OIDCProvider)(nil) | var _ Provider = (*OIDCProvider)(nil) | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	c := oauth2.Config{ | 	c := oauth2.Config{ | ||||||
|  | @ -52,23 +51,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 	} | 	} | ||||||
| 	token, err := c.Exchange(ctx, code) | 	token, err := c.Exchange(ctx, code) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("token exchange: %v", err) | 		return nil, fmt.Errorf("token exchange failure: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// in the initial exchange the id token is mandatory
 | 	return p.createSession(ctx, token, false) | ||||||
| 	idToken, err := p.findVerifiedIDToken(ctx, token) | } | ||||||
|  | 
 | ||||||
|  | // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | ||||||
|  | // such as User, Email, Groups with provider specific API calls.
 | ||||||
|  | func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | ||||||
|  | 	if p.ProfileURL.String() == "" { | ||||||
|  | 		if s.Email == "" { | ||||||
|  | 			return errors.New("id_token did not contain an email and profileURL is not defined") | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Try to get missing emails or groups from a profileURL
 | ||||||
|  | 	if s.Email == "" || len(s.Groups) == 0 { | ||||||
|  | 		err := p.callProfileURL(ctx, s) | ||||||
|  | 		if err != nil { | ||||||
|  | 			logger.Errorf("Warning: Profile URL request failed: %v", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If a mandatory email wasn't set, error at this point.
 | ||||||
|  | 	if s.Email == "" { | ||||||
|  | 		return errors.New("neither the id_token nor the profileURL set an email") | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // callProfileURL enriches a session's Email & Groups via the JSON response of
 | ||||||
|  | // an OIDC profile URL
 | ||||||
|  | func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionState) error { | ||||||
|  | 	respJSON, err := requests.New(p.ProfileURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("could not verify id_token: %v", err) | 		return err | ||||||
| 	} else if idToken == nil { |  | ||||||
| 		return nil, fmt.Errorf("token response did not contain an id_token") |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	s, err = p.createSessionState(ctx, token, idToken) | 	email, err := respJSON.Get(p.EmailClaim).String() | ||||||
| 	if err != nil { | 	if err == nil && s.Email == "" { | ||||||
| 		return nil, fmt.Errorf("unable to update session: %v", err) | 		s.Email = email | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return | 	// Handle array & singleton groups cases
 | ||||||
|  | 	if len(s.Groups) == 0 { | ||||||
|  | 		groups, err := respJSON.Get(p.GroupsClaim).StringArray() | ||||||
|  | 		if err == nil { | ||||||
|  | 			s.Groups = groups | ||||||
|  | 		} else { | ||||||
|  | 			group, err := respJSON.Get(p.GroupsClaim).String() | ||||||
|  | 			if err == nil { | ||||||
|  | 				s.Groups = []string{group} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ValidateSessionState checks that the session's IDToken is still valid
 | ||||||
|  | func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||||
|  | 	_, err := p.Verifier.Verify(ctx, s.IDToken) | ||||||
|  | 	return err == nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
|  | @ -83,14 +133,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S | ||||||
| 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn) | 	logger.Printf("refreshed session: %s", s) | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { | // redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
 | ||||||
|  | // Access Token and (probably) the ID Token.
 | ||||||
|  | func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	c := oauth2.Config{ | 	c := oauth2.Config{ | ||||||
|  | @ -109,19 +161,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi | ||||||
| 		return fmt.Errorf("failed to get token: %v", err) | 		return fmt.Errorf("failed to get token: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// in the token refresh response the id_token is optional
 | 	newSession, err := p.createSession(ctx, token, true) | ||||||
| 	idToken, err := p.findVerifiedIDToken(ctx, token) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("unable to extract id_token from response: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	newSession, err := p.createSessionState(ctx, token, idToken) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("unable create new session state from response: %v", err) | 		return fmt.Errorf("unable create new session state from response: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// It's possible that if the refresh token isn't in the token response the session will not contain an id token
 | 	// It's possible that if the refresh token isn't in the token response the
 | ||||||
| 	// if it doesn't it's probably better to retain the old one
 | 	// session will not contain an id token.
 | ||||||
|  | 	// If it doesn't it's probably better to retain the old one
 | ||||||
| 	if newSession.IDToken != "" { | 	if newSession.IDToken != "" { | ||||||
| 		s.IDToken = newSession.IDToken | 		s.IDToken = newSession.IDToken | ||||||
| 		s.Email = newSession.Email | 		s.Email = newSession.Email | ||||||
|  | @ -135,102 +182,113 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi | ||||||
| 	s.CreatedAt = newSession.CreatedAt | 	s.CreatedAt = newSession.CreatedAt | ||||||
| 	s.ExpiresOn = newSession.ExpiresOn | 	s.ExpiresOn = newSession.ExpiresOn | ||||||
| 
 | 
 | ||||||
| 	return | 	return nil | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { |  | ||||||
| 
 |  | ||||||
| 	getIDToken := func() (string, bool) { |  | ||||||
| 		rawIDToken, _ := token.Extra("id_token").(string) |  | ||||||
| 		return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0 |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if rawIDToken, present := getIDToken(); present { |  | ||||||
| 		verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken) |  | ||||||
| 		return verifiedIDToken, err |  | ||||||
| 	} |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { |  | ||||||
| 
 |  | ||||||
| 	var newSession *sessions.SessionState |  | ||||||
| 
 |  | ||||||
| 	if idToken == nil { |  | ||||||
| 		newSession = &sessions.SessionState{} |  | ||||||
| 	} else { |  | ||||||
| 		var err error |  | ||||||
| 		newSession, err = p.createSessionStateInternal(ctx, idToken, token) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	created := time.Now() |  | ||||||
| 	newSession.AccessToken = token.AccessToken |  | ||||||
| 	newSession.RefreshToken = token.RefreshToken |  | ||||||
| 	newSession.CreatedAt = &created |  | ||||||
| 	newSession.ExpiresOn = &token.Expiry |  | ||||||
| 	return newSession, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CreateSessionFromToken converts Bearer IDTokens into sessions
 | ||||||
| func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | ||||||
| 	idToken, err := p.Verifier.Verify(ctx, token) | 	idToken, err := p.Verifier.Verify(ctx, token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	newSession, err := p.createSessionStateInternal(ctx, idToken, nil) | 	ss, err := p.buildSessionFromClaims(idToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	newSession.AccessToken = token | 	// Allow empty Email in Bearer case since we can't hit the ProfileURL
 | ||||||
| 	newSession.IDToken = token | 	if ss.Email == "" { | ||||||
| 	newSession.RefreshToken = "" | 		ss.Email = ss.User | ||||||
| 	newSession.ExpiresOn = &idToken.Expiry |  | ||||||
| 
 |  | ||||||
| 	return newSession, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { |  | ||||||
| 
 |  | ||||||
| 	newSession := &sessions.SessionState{} |  | ||||||
| 
 |  | ||||||
| 	if idToken == nil { |  | ||||||
| 		return newSession, nil |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	claims, err := p.findClaimsFromIDToken(ctx, idToken, token) | 	ss.AccessToken = token | ||||||
|  | 	ss.IDToken = token | ||||||
|  | 	ss.RefreshToken = "" | ||||||
|  | 	ss.ExpiresOn = &idToken.Expiry | ||||||
|  | 
 | ||||||
|  | 	return ss, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // createSession takes an oauth2.Token and creates a SessionState from it.
 | ||||||
|  | // It alters behavior if called from Redeem vs Refresh
 | ||||||
|  | func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { | ||||||
|  | 	idToken, err := p.findVerifiedIDToken(ctx, token) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not verify id_token: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// IDToken is mandatory in Redeem but optional in Refresh
 | ||||||
|  | 	if idToken == nil && !refresh { | ||||||
|  | 		return nil, errors.New("token response did not contain an id_token") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ss, err := p.buildSessionFromClaims(idToken) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ss.AccessToken = token.AccessToken | ||||||
|  | 	ss.RefreshToken = token.RefreshToken | ||||||
|  | 	ss.IDToken = getIDToken(token) | ||||||
|  | 
 | ||||||
|  | 	created := time.Now() | ||||||
|  | 	ss.CreatedAt = &created | ||||||
|  | 	ss.ExpiresOn = &token.Expiry | ||||||
|  | 
 | ||||||
|  | 	return ss, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { | ||||||
|  | 	rawIDToken := getIDToken(token) | ||||||
|  | 	if strings.TrimSpace(rawIDToken) != "" { | ||||||
|  | 		return p.Verifier.Verify(ctx, rawIDToken) | ||||||
|  | 	} | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
 | ||||||
|  | // with non-Token related fields.
 | ||||||
|  | func (p *OIDCProvider) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { | ||||||
|  | 	ss := &sessions.SessionState{} | ||||||
|  | 
 | ||||||
|  | 	if idToken == nil { | ||||||
|  | 		return ss, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	claims, err := p.getClaims(idToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) | 		return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if token != nil { | 	ss.User = claims.Subject | ||||||
| 		newSession.IDToken = token.Extra("id_token").(string) | 	ss.Email = claims.Email | ||||||
|  | 	ss.Groups = claims.Groups | ||||||
|  | 
 | ||||||
|  | 	// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
 | ||||||
|  | 	if pref, ok := claims.rawClaims["preferred_username"].(string); ok { | ||||||
|  | 		ss.PreferredUsername = pref | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
 | 	verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail | ||||||
| 
 |  | ||||||
| 	newSession.User = claims.Subject |  | ||||||
| 	newSession.Groups = claims.Groups |  | ||||||
| 	newSession.PreferredUsername = claims.PreferredUsername |  | ||||||
| 
 |  | ||||||
| 	verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail |  | ||||||
| 	if verifyEmail && claims.Verified != nil && !*claims.Verified { | 	if verifyEmail && claims.Verified != nil && !*claims.Verified { | ||||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID) | 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return newSession, nil | 	return ss, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState checks that the session's IDToken is still valid
 | type OIDCClaims struct { | ||||||
| func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | 	Subject  string   `json:"sub"` | ||||||
| 	_, err := p.Verifier.Verify(ctx, s.IDToken) | 	Email    string   `json:"-"` | ||||||
| 	return err == nil | 	Groups   []string `json:"-"` | ||||||
|  | 	Verified *bool    `json:"email_verified"` | ||||||
|  | 
 | ||||||
|  | 	rawClaims map[string]interface{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { | // getClaims extracts IDToken claims into an OIDCClaims
 | ||||||
|  | func (p *OIDCProvider) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { | ||||||
| 	claims := &OIDCClaims{} | 	claims := &OIDCClaims{} | ||||||
| 
 | 
 | ||||||
| 	// Extract default claims.
 | 	// Extract default claims.
 | ||||||
|  | @ -242,86 +300,28 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | ||||||
| 		return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) | 		return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	userID := claims.rawClaims[p.UserIDClaim] | 	email := claims.rawClaims[p.EmailClaim] | ||||||
| 	if userID != nil { | 	if email != nil { | ||||||
| 		claims.UserID = fmt.Sprint(userID) | 		claims.Email = fmt.Sprint(email) | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims) |  | ||||||
| 
 |  | ||||||
| 	// userID claim was not present or was empty in the ID Token
 |  | ||||||
| 	if claims.UserID == "" { |  | ||||||
| 		// BearerToken case, allow empty UserID
 |  | ||||||
| 		// ProfileURL checks below won't work since we don't have an access token
 |  | ||||||
| 		if token == nil { |  | ||||||
| 			claims.UserID = claims.Subject |  | ||||||
| 			return claims, nil |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		profileURL := p.ProfileURL.String() |  | ||||||
| 		if profileURL == "" || token.AccessToken == "" { |  | ||||||
| 			return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
 |  | ||||||
| 		// contents at the profileURL contains the email.
 |  | ||||||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 |  | ||||||
| 		respJSON, err := requests.New(profileURL). |  | ||||||
| 			WithContext(ctx). |  | ||||||
| 			WithHeaders(makeOIDCHeader(token.AccessToken)). |  | ||||||
| 			Do(). |  | ||||||
| 			UnmarshalJSON() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		userID, err := respJSON.Get(p.UserIDClaim).String() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		claims.UserID = userID |  | ||||||
| 	} | 	} | ||||||
|  | 	claims.Groups = p.extractGroups(claims.rawClaims) | ||||||
| 
 | 
 | ||||||
| 	return claims, nil | 	return claims, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string { | func (p *OIDCProvider) extractGroups(claims map[string]interface{}) []string { | ||||||
| 	groups := []string{} | 	groups := []string{} | ||||||
| 
 | 	rawGroups, ok := claims[p.GroupsClaim].([]interface{}) | ||||||
| 	rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{}) |  | ||||||
| 	if rawGroups != nil && ok { | 	if rawGroups != nil && ok { | ||||||
| 		for _, rawGroup := range rawGroups { | 		for _, rawGroup := range rawGroups { | ||||||
| 			formattedGroup, err := formatGroup(rawGroup) | 			formattedGroup, err := formatGroup(rawGroup) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err) | 				logger.Errorf("Warning: unable to format group of type %s with error %s", | ||||||
|  | 					reflect.TypeOf(rawGroup), err) | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			groups = append(groups, formattedGroup) | 			groups = append(groups, formattedGroup) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	return groups | 	return groups | ||||||
| 
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func formatGroup(rawGroup interface{}) (string, error) { |  | ||||||
| 	group, ok := rawGroup.(string) |  | ||||||
| 	if !ok { |  | ||||||
| 		jsonGroup, err := json.Marshal(rawGroup) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return "", err |  | ||||||
| 		} |  | ||||||
| 		group = string(jsonGroup) |  | ||||||
| 	} |  | ||||||
| 	return group, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type OIDCClaims struct { |  | ||||||
| 	rawClaims         map[string]interface{} |  | ||||||
| 	UserID            string |  | ||||||
| 	Subject           string   `json:"sub"` |  | ||||||
| 	Verified          *bool    `json:"email_verified"` |  | ||||||
| 	PreferredUsername string   `json:"preferred_username"` |  | ||||||
| 	Groups            []string `json:"-"` |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -154,7 +154,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { | ||||||
| 
 | 
 | ||||||
| 	p := &OIDCProvider{ | 	p := &OIDCProvider{ | ||||||
| 		ProviderData: providerData, | 		ProviderData: providerData, | ||||||
| 		UserIDClaim:  "email", | 		EmailClaim:   "email", | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return p | 	return p | ||||||
|  | @ -225,7 +225,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	server, provider := newTestSetup(body) | 	server, provider := newTestSetup(body) | ||||||
| 	provider.UserIDClaim = "phone_number" | 	provider.EmailClaim = "phone_number" | ||||||
| 	defer server.Close() | 	defer server.Close() | ||||||
| 
 | 
 | ||||||
| 	session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") | 	session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") | ||||||
|  | @ -233,6 +233,256 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { | ||||||
| 	assert.Equal(t, defaultIDToken.Phone, session.Email) | 	assert.Equal(t, defaultIDToken.Phone, session.Email) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestOIDCProvider_EnrichSession(t *testing.T) { | ||||||
|  | 	const ( | ||||||
|  | 		idToken      = "Unchanged ID Token" | ||||||
|  | 		accessToken  = "Unchanged Access Token" | ||||||
|  | 		refreshToken = "Unchanged Refresh Token" | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	testCases := map[string]struct { | ||||||
|  | 		ExistingSession *sessions.SessionState | ||||||
|  | 		EmailClaim      string | ||||||
|  | 		GroupsClaim     string | ||||||
|  | 		ProfileJSON     map[string]interface{} | ||||||
|  | 		ExpectedError   error | ||||||
|  | 		ExpectedSession *sessions.SessionState | ||||||
|  | 	}{ | ||||||
|  | 		"Already Populated": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"email":  "new@thing.com", | ||||||
|  | 				"groups": []string{"new", "thing"}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Missing Email": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"email":  "found@email.com", | ||||||
|  | 				"groups": []string{"new", "thing"}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				Email:        "found@email.com", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 
 | ||||||
|  | 		"Missing Email Only in Profile URL": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"email": "found@email.com", | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				Email:        "found@email.com", | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Missing Email with Custom Claim": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "weird", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"weird":  "weird@claim.com", | ||||||
|  | 				"groups": []string{"new", "thing"}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				Email:        "weird@claim.com", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Missing Email not in Profile URL": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"groups": []string{"new", "thing"}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: errors.New("neither the id_token nor the profileURL set an email"), | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "missing.email", | ||||||
|  | 				Groups:       []string{"already", "populated"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Missing Groups": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       []string{}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"email":  "new@thing.com", | ||||||
|  | 				"groups": []string{"new", "thing"}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       []string{"new", "thing"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Missing Groups with Custom Claim": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       nil, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "roles", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"email": "new@thing.com", | ||||||
|  | 				"roles": []string{"new", "thing", "roles"}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       []string{"new", "thing", "roles"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Missing Groups String Profile URL Response": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       []string{}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"email":  "new@thing.com", | ||||||
|  | 				"groups": "singleton", | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				Groups:       []string{"singleton"}, | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Missing Groups in both Claims and Profile URL": { | ||||||
|  | 			ExistingSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 			EmailClaim:  "email", | ||||||
|  | 			GroupsClaim: "groups", | ||||||
|  | 			ProfileJSON: map[string]interface{}{ | ||||||
|  | 				"email": "new@thing.com", | ||||||
|  | 			}, | ||||||
|  | 			ExpectedError: nil, | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:         "already", | ||||||
|  | 				Email:        "already@populated.com", | ||||||
|  | 				IDToken:      idToken, | ||||||
|  | 				AccessToken:  accessToken, | ||||||
|  | 				RefreshToken: refreshToken, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for testName, tc := range testCases { | ||||||
|  | 		t.Run(testName, func(t *testing.T) { | ||||||
|  | 			jsonResp, err := json.Marshal(tc.ProfileJSON) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 			server, provider := newTestSetup(jsonResp) | ||||||
|  | 			provider.ProfileURL, err = url.Parse(server.URL) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 			provider.EmailClaim = tc.EmailClaim | ||||||
|  | 			provider.GroupsClaim = tc.GroupsClaim | ||||||
|  | 			defer server.Close() | ||||||
|  | 
 | ||||||
|  | 			err = provider.EnrichSession(context.Background(), tc.ExistingSession) | ||||||
|  | 			assert.Equal(t, tc.ExpectedError, err) | ||||||
|  | 			assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	idToken, _ := newSignedTestIDToken(defaultIDToken) | 	idToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
|  | @ -361,7 +611,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { | func TestOIDCProvider_findVerifiedIDToken(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	server, provider := newTestSetup([]byte("")) | 	server, provider := newTestSetup([]byte("")) | ||||||
| 
 | 
 | ||||||
|  | @ -397,31 +647,3 @@ func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, true, verifiedIDToken == nil) | 	assert.Equal(t, true, verifiedIDToken == nil) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func Test_formatGroup(t *testing.T) { |  | ||||||
| 	testCases := map[string]struct { |  | ||||||
| 		RawGroup                    interface{} |  | ||||||
| 		ExpectedFormattedGroupValue string |  | ||||||
| 	}{ |  | ||||||
| 		"String Group": { |  | ||||||
| 			RawGroup:                    "group", |  | ||||||
| 			ExpectedFormattedGroupValue: "group", |  | ||||||
| 		}, |  | ||||||
| 		"Map Group": { |  | ||||||
| 			RawGroup:                    map[string]string{"id": "1", "name": "Test"}, |  | ||||||
| 			ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}", |  | ||||||
| 		}, |  | ||||||
| 		"List Group": { |  | ||||||
| 			RawGroup:                    []string{"First", "Second"}, |  | ||||||
| 			ExpectedFormattedGroupValue: "[\"First\",\"Second\"]", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for testName, tc := range testCases { |  | ||||||
| 		t.Run(testName, func(t *testing.T) { |  | ||||||
| 			formattedGroup, err := formatGroup(tc.RawGroup) |  | ||||||
| 			assert.Nil(t, err) |  | ||||||
| 			assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup) |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -1,9 +1,12 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 
 | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -55,3 +58,23 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va | ||||||
| 	a.RawQuery = params.Encode() | 	a.RawQuery = params.Encode() | ||||||
| 	return a | 	return a | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func getIDToken(token *oauth2.Token) string { | ||||||
|  | 	idToken, ok := token.Extra("id_token").(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return idToken | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func formatGroup(rawGroup interface{}) (string, error) { | ||||||
|  | 	group, ok := rawGroup.(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		jsonGroup, err := json.Marshal(rawGroup) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  | 		group = string(jsonGroup) | ||||||
|  | 	} | ||||||
|  | 	return group, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -5,9 +5,10 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestMakeAuhtorizationHeader(t *testing.T) { | func Test_makeAuthorizationHeader(t *testing.T) { | ||||||
| 	testCases := []struct { | 	testCases := []struct { | ||||||
| 		name         string | 		name         string | ||||||
| 		prefix       string | 		prefix       string | ||||||
|  | @ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) { | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func Test_getIDToken(t *testing.T) { | ||||||
|  | 	const idToken = "eyJfoobar.eyJfoobar.12345asdf" | ||||||
|  | 	g := NewWithT(t) | ||||||
|  | 
 | ||||||
|  | 	token := &oauth2.Token{} | ||||||
|  | 	g.Expect(getIDToken(token)).To(Equal("")) | ||||||
|  | 
 | ||||||
|  | 	extraToken := token.WithExtra(map[string]interface{}{ | ||||||
|  | 		"id_token": idToken, | ||||||
|  | 	}) | ||||||
|  | 	g.Expect(getIDToken(extraToken)).To(Equal(idToken)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Test_formatGroup(t *testing.T) { | ||||||
|  | 	testCases := map[string]struct { | ||||||
|  | 		rawGroup interface{} | ||||||
|  | 		expected string | ||||||
|  | 	}{ | ||||||
|  | 		"String Group": { | ||||||
|  | 			rawGroup: "group", | ||||||
|  | 			expected: "group", | ||||||
|  | 		}, | ||||||
|  | 		"Numeric Group": { | ||||||
|  | 			rawGroup: 123, | ||||||
|  | 			expected: "123", | ||||||
|  | 		}, | ||||||
|  | 		"Map Group": { | ||||||
|  | 			rawGroup: map[string]string{"id": "1", "name": "Test"}, | ||||||
|  | 			expected: "{\"id\":\"1\",\"name\":\"Test\"}", | ||||||
|  | 		}, | ||||||
|  | 		"List Group": { | ||||||
|  | 			rawGroup: []string{"First", "Second"}, | ||||||
|  | 			expected: "[\"First\",\"Second\"]", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for testName, tc := range testCases { | ||||||
|  | 		t.Run(testName, func(t *testing.T) { | ||||||
|  | 			g := NewWithT(t) | ||||||
|  | 			formattedGroup, err := formatGroup(tc.rawGroup) | ||||||
|  | 			g.Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			g.Expect(formattedGroup).To(Equal(tc.expected)) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue