Refactor OIDC to EnrichSession
This commit is contained in:
		
							parent
							
								
									4fda907830
								
							
						
					
					
						commit
						a1877434b2
					
				|  | @ -274,7 +274,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | |||
| 		p.SetRepository(o.BitbucketRepository) | ||||
| 	case *providers.OIDCProvider: | ||||
| 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | ||||
| 		p.UserIDClaim = o.UserIDClaim | ||||
| 		p.EmailClaim = o.UserIDClaim | ||||
| 		p.GroupsClaim = o.OIDCGroupsClaim | ||||
| 		if p.Verifier == nil { | ||||
| 			msgs = append(msgs, "oidc provider requires an oidc issuer URL") | ||||
|  |  | |||
|  | @ -2,18 +2,17 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	oidc "github.com/coreos/go-oidc" | ||||
| 	"golang.org/x/oauth2" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
| const emailClaim = "email" | ||||
|  | @ -23,7 +22,7 @@ type OIDCProvider struct { | |||
| 	*ProviderData | ||||
| 
 | ||||
| 	AllowUnverifiedEmail bool | ||||
| 	UserIDClaim          string | ||||
| 	EmailClaim           string | ||||
| 	GroupsClaim          string | ||||
| } | ||||
| 
 | ||||
|  | @ -36,10 +35,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { | |||
| var _ Provider = (*OIDCProvider)(nil) | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||
| 	clientSecret, err := p.GetClientSecret() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	c := oauth2.Config{ | ||||
|  | @ -52,23 +51,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 	} | ||||
| 	token, err := c.Exchange(ctx, code) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("token exchange: %v", err) | ||||
| 		return nil, fmt.Errorf("token exchange failure: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// in the initial exchange the id token is mandatory
 | ||||
| 	idToken, err := p.findVerifiedIDToken(ctx, token) | ||||
| 	return p.createSession(ctx, token, false) | ||||
| } | ||||
| 
 | ||||
| // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | ||||
| // such as User, Email, Groups with provider specific API calls.
 | ||||
| func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | ||||
| 	if p.ProfileURL.String() == "" { | ||||
| 		if s.Email == "" { | ||||
| 			return errors.New("id_token did not contain an email and profileURL is not defined") | ||||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Try to get missing emails or groups from a profileURL
 | ||||
| 	if s.Email == "" || len(s.Groups) == 0 { | ||||
| 		err := p.callProfileURL(ctx, s) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("Warning: Profile URL request failed: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// If a mandatory email wasn't set, error at this point.
 | ||||
| 	if s.Email == "" { | ||||
| 		return errors.New("neither the id_token nor the profileURL set an email") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // callProfileURL enriches a session's Email & Groups via the JSON response of
 | ||||
| // an OIDC profile URL
 | ||||
| func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionState) error { | ||||
| 	respJSON, err := requests.New(p.ProfileURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not verify id_token: %v", err) | ||||
| 	} else if idToken == nil { | ||||
| 		return nil, fmt.Errorf("token response did not contain an id_token") | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	s, err = p.createSessionState(ctx, token, idToken) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("unable to update session: %v", err) | ||||
| 	email, err := respJSON.Get(p.EmailClaim).String() | ||||
| 	if err == nil && s.Email == "" { | ||||
| 		s.Email = email | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| 	// Handle array & singleton groups cases
 | ||||
| 	if len(s.Groups) == 0 { | ||||
| 		groups, err := respJSON.Get(p.GroupsClaim).StringArray() | ||||
| 		if err == nil { | ||||
| 			s.Groups = groups | ||||
| 		} else { | ||||
| 			group, err := respJSON.Get(p.GroupsClaim).String() | ||||
| 			if err == nil { | ||||
| 				s.Groups = []string{group} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // ValidateSessionState checks that the session's IDToken is still valid
 | ||||
| func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||
| 	_, err := p.Verifier.Verify(ctx, s.IDToken) | ||||
| 	return err == nil | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
|  | @ -83,14 +133,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S | |||
| 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn) | ||||
| 	logger.Printf("refreshed session: %s", s) | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { | ||||
| // redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
 | ||||
| // Access Token and (probably) the ID Token.
 | ||||
| func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { | ||||
| 	clientSecret, err := p.GetClientSecret() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	c := oauth2.Config{ | ||||
|  | @ -109,19 +161,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi | |||
| 		return fmt.Errorf("failed to get token: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// in the token refresh response the id_token is optional
 | ||||
| 	idToken, err := p.findVerifiedIDToken(ctx, token) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("unable to extract id_token from response: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	newSession, err := p.createSessionState(ctx, token, idToken) | ||||
| 	newSession, err := p.createSession(ctx, token, true) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("unable create new session state from response: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// It's possible that if the refresh token isn't in the token response the session will not contain an id token
 | ||||
| 	// if it doesn't it's probably better to retain the old one
 | ||||
| 	// It's possible that if the refresh token isn't in the token response the
 | ||||
| 	// session will not contain an id token.
 | ||||
| 	// If it doesn't it's probably better to retain the old one
 | ||||
| 	if newSession.IDToken != "" { | ||||
| 		s.IDToken = newSession.IDToken | ||||
| 		s.Email = newSession.Email | ||||
|  | @ -135,102 +182,113 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi | |||
| 	s.CreatedAt = newSession.CreatedAt | ||||
| 	s.ExpiresOn = newSession.ExpiresOn | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { | ||||
| 
 | ||||
| 	getIDToken := func() (string, bool) { | ||||
| 		rawIDToken, _ := token.Extra("id_token").(string) | ||||
| 		return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0 | ||||
| 	} | ||||
| 
 | ||||
| 	if rawIDToken, present := getIDToken(); present { | ||||
| 		verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken) | ||||
| 		return verifiedIDToken, err | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { | ||||
| 
 | ||||
| 	var newSession *sessions.SessionState | ||||
| 
 | ||||
| 	if idToken == nil { | ||||
| 		newSession = &sessions.SessionState{} | ||||
| 	} else { | ||||
| 		var err error | ||||
| 		newSession, err = p.createSessionStateInternal(ctx, idToken, token) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	newSession.AccessToken = token.AccessToken | ||||
| 	newSession.RefreshToken = token.RefreshToken | ||||
| 	newSession.CreatedAt = &created | ||||
| 	newSession.ExpiresOn = &token.Expiry | ||||
| 	return newSession, nil | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // CreateSessionFromToken converts Bearer IDTokens into sessions
 | ||||
| func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | ||||
| 	idToken, err := p.Verifier.Verify(ctx, token) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	newSession, err := p.createSessionStateInternal(ctx, idToken, nil) | ||||
| 	ss, err := p.buildSessionFromClaims(idToken) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	newSession.AccessToken = token | ||||
| 	newSession.IDToken = token | ||||
| 	newSession.RefreshToken = "" | ||||
| 	newSession.ExpiresOn = &idToken.Expiry | ||||
| 
 | ||||
| 	return newSession, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { | ||||
| 
 | ||||
| 	newSession := &sessions.SessionState{} | ||||
| 
 | ||||
| 	if idToken == nil { | ||||
| 		return newSession, nil | ||||
| 	// Allow empty Email in Bearer case since we can't hit the ProfileURL
 | ||||
| 	if ss.Email == "" { | ||||
| 		ss.Email = ss.User | ||||
| 	} | ||||
| 
 | ||||
| 	claims, err := p.findClaimsFromIDToken(ctx, idToken, token) | ||||
| 	ss.AccessToken = token | ||||
| 	ss.IDToken = token | ||||
| 	ss.RefreshToken = "" | ||||
| 	ss.ExpiresOn = &idToken.Expiry | ||||
| 
 | ||||
| 	return ss, nil | ||||
| } | ||||
| 
 | ||||
| // createSession takes an oauth2.Token and creates a SessionState from it.
 | ||||
| // It alters behavior if called from Redeem vs Refresh
 | ||||
| func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { | ||||
| 	idToken, err := p.findVerifiedIDToken(ctx, token) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not verify id_token: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// IDToken is mandatory in Redeem but optional in Refresh
 | ||||
| 	if idToken == nil && !refresh { | ||||
| 		return nil, errors.New("token response did not contain an id_token") | ||||
| 	} | ||||
| 
 | ||||
| 	ss, err := p.buildSessionFromClaims(idToken) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	ss.AccessToken = token.AccessToken | ||||
| 	ss.RefreshToken = token.RefreshToken | ||||
| 	ss.IDToken = getIDToken(token) | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	ss.CreatedAt = &created | ||||
| 	ss.ExpiresOn = &token.Expiry | ||||
| 
 | ||||
| 	return ss, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { | ||||
| 	rawIDToken := getIDToken(token) | ||||
| 	if strings.TrimSpace(rawIDToken) != "" { | ||||
| 		return p.Verifier.Verify(ctx, rawIDToken) | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
 | ||||
| // with non-Token related fields.
 | ||||
| func (p *OIDCProvider) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { | ||||
| 	ss := &sessions.SessionState{} | ||||
| 
 | ||||
| 	if idToken == nil { | ||||
| 		return ss, nil | ||||
| 	} | ||||
| 
 | ||||
| 	claims, err := p.getClaims(idToken) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if token != nil { | ||||
| 		newSession.IDToken = token.Extra("id_token").(string) | ||||
| 	ss.User = claims.Subject | ||||
| 	ss.Email = claims.Email | ||||
| 	ss.Groups = claims.Groups | ||||
| 
 | ||||
| 	// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
 | ||||
| 	if pref, ok := claims.rawClaims["preferred_username"].(string); ok { | ||||
| 		ss.PreferredUsername = pref | ||||
| 	} | ||||
| 
 | ||||
| 	newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
 | ||||
| 
 | ||||
| 	newSession.User = claims.Subject | ||||
| 	newSession.Groups = claims.Groups | ||||
| 	newSession.PreferredUsername = claims.PreferredUsername | ||||
| 
 | ||||
| 	verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail | ||||
| 	verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail | ||||
| 	if verifyEmail && claims.Verified != nil && !*claims.Verified { | ||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID) | ||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	return newSession, nil | ||||
| 	return ss, nil | ||||
| } | ||||
| 
 | ||||
| // ValidateSessionState checks that the session's IDToken is still valid
 | ||||
| func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||
| 	_, err := p.Verifier.Verify(ctx, s.IDToken) | ||||
| 	return err == nil | ||||
| type OIDCClaims struct { | ||||
| 	Subject  string   `json:"sub"` | ||||
| 	Email    string   `json:"-"` | ||||
| 	Groups   []string `json:"-"` | ||||
| 	Verified *bool    `json:"email_verified"` | ||||
| 
 | ||||
| 	rawClaims map[string]interface{} | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { | ||||
| // getClaims extracts IDToken claims into an OIDCClaims
 | ||||
| func (p *OIDCProvider) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { | ||||
| 	claims := &OIDCClaims{} | ||||
| 
 | ||||
| 	// Extract default claims.
 | ||||
|  | @ -242,86 +300,28 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | |||
| 		return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	userID := claims.rawClaims[p.UserIDClaim] | ||||
| 	if userID != nil { | ||||
| 		claims.UserID = fmt.Sprint(userID) | ||||
| 	} | ||||
| 
 | ||||
| 	claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims) | ||||
| 
 | ||||
| 	// userID claim was not present or was empty in the ID Token
 | ||||
| 	if claims.UserID == "" { | ||||
| 		// BearerToken case, allow empty UserID
 | ||||
| 		// ProfileURL checks below won't work since we don't have an access token
 | ||||
| 		if token == nil { | ||||
| 			claims.UserID = claims.Subject | ||||
| 			return claims, nil | ||||
| 		} | ||||
| 
 | ||||
| 		profileURL := p.ProfileURL.String() | ||||
| 		if profileURL == "" || token.AccessToken == "" { | ||||
| 			return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim) | ||||
| 		} | ||||
| 
 | ||||
| 		// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
 | ||||
| 		// contents at the profileURL contains the email.
 | ||||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | ||||
| 		respJSON, err := requests.New(profileURL). | ||||
| 			WithContext(ctx). | ||||
| 			WithHeaders(makeOIDCHeader(token.AccessToken)). | ||||
| 			Do(). | ||||
| 			UnmarshalJSON() | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		userID, err := respJSON.Get(p.UserIDClaim).String() | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim) | ||||
| 		} | ||||
| 
 | ||||
| 		claims.UserID = userID | ||||
| 	email := claims.rawClaims[p.EmailClaim] | ||||
| 	if email != nil { | ||||
| 		claims.Email = fmt.Sprint(email) | ||||
| 	} | ||||
| 	claims.Groups = p.extractGroups(claims.rawClaims) | ||||
| 
 | ||||
| 	return claims, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string { | ||||
| func (p *OIDCProvider) extractGroups(claims map[string]interface{}) []string { | ||||
| 	groups := []string{} | ||||
| 
 | ||||
| 	rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{}) | ||||
| 	rawGroups, ok := claims[p.GroupsClaim].([]interface{}) | ||||
| 	if rawGroups != nil && ok { | ||||
| 		for _, rawGroup := range rawGroups { | ||||
| 			formattedGroup, err := formatGroup(rawGroup) | ||||
| 			if err != nil { | ||||
| 				logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err) | ||||
| 				logger.Errorf("Warning: unable to format group of type %s with error %s", | ||||
| 					reflect.TypeOf(rawGroup), err) | ||||
| 				continue | ||||
| 			} | ||||
| 			groups = append(groups, formattedGroup) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return groups | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func formatGroup(rawGroup interface{}) (string, error) { | ||||
| 	group, ok := rawGroup.(string) | ||||
| 	if !ok { | ||||
| 		jsonGroup, err := json.Marshal(rawGroup) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		group = string(jsonGroup) | ||||
| 	} | ||||
| 	return group, nil | ||||
| } | ||||
| 
 | ||||
| type OIDCClaims struct { | ||||
| 	rawClaims         map[string]interface{} | ||||
| 	UserID            string | ||||
| 	Subject           string   `json:"sub"` | ||||
| 	Verified          *bool    `json:"email_verified"` | ||||
| 	PreferredUsername string   `json:"preferred_username"` | ||||
| 	Groups            []string `json:"-"` | ||||
| } | ||||
|  |  | |||
|  | @ -154,7 +154,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { | |||
| 
 | ||||
| 	p := &OIDCProvider{ | ||||
| 		ProviderData: providerData, | ||||
| 		UserIDClaim:  "email", | ||||
| 		EmailClaim:   "email", | ||||
| 	} | ||||
| 
 | ||||
| 	return p | ||||
|  | @ -225,7 +225,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { | |||
| 	}) | ||||
| 
 | ||||
| 	server, provider := newTestSetup(body) | ||||
| 	provider.UserIDClaim = "phone_number" | ||||
| 	provider.EmailClaim = "phone_number" | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") | ||||
|  | @ -233,6 +233,256 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { | |||
| 	assert.Equal(t, defaultIDToken.Phone, session.Email) | ||||
| } | ||||
| 
 | ||||
| func TestOIDCProvider_EnrichSession(t *testing.T) { | ||||
| 	const ( | ||||
| 		idToken      = "Unchanged ID Token" | ||||
| 		accessToken  = "Unchanged Access Token" | ||||
| 		refreshToken = "Unchanged Refresh Token" | ||||
| 	) | ||||
| 
 | ||||
| 	testCases := map[string]struct { | ||||
| 		ExistingSession *sessions.SessionState | ||||
| 		EmailClaim      string | ||||
| 		GroupsClaim     string | ||||
| 		ProfileJSON     map[string]interface{} | ||||
| 		ExpectedError   error | ||||
| 		ExpectedSession *sessions.SessionState | ||||
| 	}{ | ||||
| 		"Already Populated": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"email":  "new@thing.com", | ||||
| 				"groups": []string{"new", "thing"}, | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"Missing Email": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"email":  "found@email.com", | ||||
| 				"groups": []string{"new", "thing"}, | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				Email:        "found@email.com", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 
 | ||||
| 		"Missing Email Only in Profile URL": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"email": "found@email.com", | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				Email:        "found@email.com", | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"Missing Email with Custom Claim": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "weird", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"weird":  "weird@claim.com", | ||||
| 				"groups": []string{"new", "thing"}, | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				Email:        "weird@claim.com", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"Missing Email not in Profile URL": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"groups": []string{"new", "thing"}, | ||||
| 			}, | ||||
| 			ExpectedError: errors.New("neither the id_token nor the profileURL set an email"), | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "missing.email", | ||||
| 				Groups:       []string{"already", "populated"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"Missing Groups": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       []string{}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"email":  "new@thing.com", | ||||
| 				"groups": []string{"new", "thing"}, | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       []string{"new", "thing"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"Missing Groups with Custom Claim": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       nil, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "roles", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"email": "new@thing.com", | ||||
| 				"roles": []string{"new", "thing", "roles"}, | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       []string{"new", "thing", "roles"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"Missing Groups String Profile URL Response": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       []string{}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"email":  "new@thing.com", | ||||
| 				"groups": "singleton", | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				Groups:       []string{"singleton"}, | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"Missing Groups in both Claims and Profile URL": { | ||||
| 			ExistingSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 			EmailClaim:  "email", | ||||
| 			GroupsClaim: "groups", | ||||
| 			ProfileJSON: map[string]interface{}{ | ||||
| 				"email": "new@thing.com", | ||||
| 			}, | ||||
| 			ExpectedError: nil, | ||||
| 			ExpectedSession: &sessions.SessionState{ | ||||
| 				User:         "already", | ||||
| 				Email:        "already@populated.com", | ||||
| 				IDToken:      idToken, | ||||
| 				AccessToken:  accessToken, | ||||
| 				RefreshToken: refreshToken, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	for testName, tc := range testCases { | ||||
| 		t.Run(testName, func(t *testing.T) { | ||||
| 			jsonResp, err := json.Marshal(tc.ProfileJSON) | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			server, provider := newTestSetup(jsonResp) | ||||
| 			provider.ProfileURL, err = url.Parse(server.URL) | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			provider.EmailClaim = tc.EmailClaim | ||||
| 			provider.GroupsClaim = tc.GroupsClaim | ||||
| 			defer server.Close() | ||||
| 
 | ||||
| 			err = provider.EnrichSession(context.Background(), tc.ExistingSession) | ||||
| 			assert.Equal(t, tc.ExpectedError, err) | ||||
| 			assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | ||||
| 
 | ||||
| 	idToken, _ := newSignedTestIDToken(defaultIDToken) | ||||
|  | @ -361,7 +611,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { | ||||
| func TestOIDCProvider_findVerifiedIDToken(t *testing.T) { | ||||
| 
 | ||||
| 	server, provider := newTestSetup([]byte("")) | ||||
| 
 | ||||
|  | @ -397,31 +647,3 @@ func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { | |||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, true, verifiedIDToken == nil) | ||||
| } | ||||
| 
 | ||||
| func Test_formatGroup(t *testing.T) { | ||||
| 	testCases := map[string]struct { | ||||
| 		RawGroup                    interface{} | ||||
| 		ExpectedFormattedGroupValue string | ||||
| 	}{ | ||||
| 		"String Group": { | ||||
| 			RawGroup:                    "group", | ||||
| 			ExpectedFormattedGroupValue: "group", | ||||
| 		}, | ||||
| 		"Map Group": { | ||||
| 			RawGroup:                    map[string]string{"id": "1", "name": "Test"}, | ||||
| 			ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}", | ||||
| 		}, | ||||
| 		"List Group": { | ||||
| 			RawGroup:                    []string{"First", "Second"}, | ||||
| 			ExpectedFormattedGroupValue: "[\"First\",\"Second\"]", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for testName, tc := range testCases { | ||||
| 		t.Run(testName, func(t *testing.T) { | ||||
| 			formattedGroup, err := formatGroup(tc.RawGroup) | ||||
| 			assert.Nil(t, err) | ||||
| 			assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -1,9 +1,12 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 
 | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -55,3 +58,23 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va | |||
| 	a.RawQuery = params.Encode() | ||||
| 	return a | ||||
| } | ||||
| 
 | ||||
| func getIDToken(token *oauth2.Token) string { | ||||
| 	idToken, ok := token.Extra("id_token").(string) | ||||
| 	if !ok { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return idToken | ||||
| } | ||||
| 
 | ||||
| func formatGroup(rawGroup interface{}) (string, error) { | ||||
| 	group, ok := rawGroup.(string) | ||||
| 	if !ok { | ||||
| 		jsonGroup, err := json.Marshal(rawGroup) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		group = string(jsonGroup) | ||||
| 	} | ||||
| 	return group, nil | ||||
| } | ||||
|  |  | |||
|  | @ -5,9 +5,10 @@ import ( | |||
| 	"testing" | ||||
| 
 | ||||
| 	. "github.com/onsi/gomega" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
| func TestMakeAuhtorizationHeader(t *testing.T) { | ||||
| func Test_makeAuthorizationHeader(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name         string | ||||
| 		prefix       string | ||||
|  | @ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) { | |||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Test_getIDToken(t *testing.T) { | ||||
| 	const idToken = "eyJfoobar.eyJfoobar.12345asdf" | ||||
| 	g := NewWithT(t) | ||||
| 
 | ||||
| 	token := &oauth2.Token{} | ||||
| 	g.Expect(getIDToken(token)).To(Equal("")) | ||||
| 
 | ||||
| 	extraToken := token.WithExtra(map[string]interface{}{ | ||||
| 		"id_token": idToken, | ||||
| 	}) | ||||
| 	g.Expect(getIDToken(extraToken)).To(Equal(idToken)) | ||||
| } | ||||
| 
 | ||||
| func Test_formatGroup(t *testing.T) { | ||||
| 	testCases := map[string]struct { | ||||
| 		rawGroup interface{} | ||||
| 		expected string | ||||
| 	}{ | ||||
| 		"String Group": { | ||||
| 			rawGroup: "group", | ||||
| 			expected: "group", | ||||
| 		}, | ||||
| 		"Numeric Group": { | ||||
| 			rawGroup: 123, | ||||
| 			expected: "123", | ||||
| 		}, | ||||
| 		"Map Group": { | ||||
| 			rawGroup: map[string]string{"id": "1", "name": "Test"}, | ||||
| 			expected: "{\"id\":\"1\",\"name\":\"Test\"}", | ||||
| 		}, | ||||
| 		"List Group": { | ||||
| 			rawGroup: []string{"First", "Second"}, | ||||
| 			expected: "[\"First\",\"Second\"]", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for testName, tc := range testCases { | ||||
| 		t.Run(testName, func(t *testing.T) { | ||||
| 			g := NewWithT(t) | ||||
| 			formattedGroup, err := formatGroup(tc.rawGroup) | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 			g.Expect(formattedGroup).To(Equal(tc.expected)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue