diff --git a/pkg/validation/options.go b/pkg/validation/options.go index fc89000b..0cea4b2b 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -274,7 +274,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.SetRepository(o.BitbucketRepository) case *providers.OIDCProvider: p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail - p.UserIDClaim = o.UserIDClaim + p.EmailClaim = o.UserIDClaim p.GroupsClaim = o.OIDCGroupsClaim if p.Verifier == nil { msgs = append(msgs, "oidc provider requires an oidc issuer URL") diff --git a/providers/oidc.go b/providers/oidc.go index 704e3341..15020282 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -2,18 +2,17 @@ package providers import ( "context" - "encoding/json" + "errors" "fmt" "reflect" "strings" "time" - oidc "github.com/coreos/go-oidc" - "golang.org/x/oauth2" - + "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" + "golang.org/x/oauth2" ) const emailClaim = "email" @@ -23,7 +22,7 @@ type OIDCProvider struct { *ProviderData AllowUnverifiedEmail bool - UserIDClaim string + EmailClaim string GroupsClaim string } @@ -36,10 +35,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { var _ Provider = (*OIDCProvider)(nil) // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } c := oauth2.Config{ @@ -52,23 +51,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s } token, err := c.Exchange(ctx, code) if err != nil { - return nil, fmt.Errorf("token exchange: %v", err) + return nil, fmt.Errorf("token exchange failure: %v", err) } - // in the initial exchange the id token is mandatory - idToken, err := p.findVerifiedIDToken(ctx, token) + return p.createSession(ctx, token, false) +} + +// EnrichSessionState is called after Redeem to allow providers to enrich session fields +// such as User, Email, Groups with provider specific API calls. +func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { + if p.ProfileURL.String() == "" { + if s.Email == "" { + return errors.New("id_token did not contain an email and profileURL is not defined") + } + return nil + } + + // Try to get missing emails or groups from a profileURL + if s.Email == "" || len(s.Groups) == 0 { + err := p.callProfileURL(ctx, s) + if err != nil { + logger.Errorf("Warning: Profile URL request failed: %v", err) + } + } + + // If a mandatory email wasn't set, error at this point. + if s.Email == "" { + return errors.New("neither the id_token nor the profileURL set an email") + } + return nil +} + +// callProfileURL enriches a session's Email & Groups via the JSON response of +// an OIDC profile URL +func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionState) error { + respJSON, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(makeOIDCHeader(s.AccessToken)). + Do(). + UnmarshalJSON() if err != nil { - return nil, fmt.Errorf("could not verify id_token: %v", err) - } else if idToken == nil { - return nil, fmt.Errorf("token response did not contain an id_token") + return err } - s, err = p.createSessionState(ctx, token, idToken) - if err != nil { - return nil, fmt.Errorf("unable to update session: %v", err) + email, err := respJSON.Get(p.EmailClaim).String() + if err == nil && s.Email == "" { + s.Email = email } - return + // Handle array & singleton groups cases + if len(s.Groups) == 0 { + groups, err := respJSON.Get(p.GroupsClaim).StringArray() + if err == nil { + s.Groups = groups + } else { + group, err := respJSON.Get(p.GroupsClaim).String() + if err == nil { + s.Groups = []string{group} + } + } + } + + return nil +} + +// ValidateSessionState checks that the session's IDToken is still valid +func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { + _, err := p.Verifier.Verify(ctx, s.IDToken) + return err == nil } // RefreshSessionIfNeeded checks if the session has expired and uses the @@ -83,14 +133,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn) + logger.Printf("refreshed session: %s", s) return true, nil } -func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the +// Access Token and (probably) the ID Token. +func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } c := oauth2.Config{ @@ -109,19 +161,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi return fmt.Errorf("failed to get token: %v", err) } - // in the token refresh response the id_token is optional - idToken, err := p.findVerifiedIDToken(ctx, token) - if err != nil { - return fmt.Errorf("unable to extract id_token from response: %v", err) - } - - newSession, err := p.createSessionState(ctx, token, idToken) + newSession, err := p.createSession(ctx, token, true) if err != nil { return fmt.Errorf("unable create new session state from response: %v", err) } - // It's possible that if the refresh token isn't in the token response the session will not contain an id token - // if it doesn't it's probably better to retain the old one + // It's possible that if the refresh token isn't in the token response the + // session will not contain an id token. + // If it doesn't it's probably better to retain the old one if newSession.IDToken != "" { s.IDToken = newSession.IDToken s.Email = newSession.Email @@ -135,102 +182,113 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi s.CreatedAt = newSession.CreatedAt s.ExpiresOn = newSession.ExpiresOn - return -} - -func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { - - getIDToken := func() (string, bool) { - rawIDToken, _ := token.Extra("id_token").(string) - return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0 - } - - if rawIDToken, present := getIDToken(); present { - verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken) - return verifiedIDToken, err - } - return nil, nil -} - -func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { - - var newSession *sessions.SessionState - - if idToken == nil { - newSession = &sessions.SessionState{} - } else { - var err error - newSession, err = p.createSessionStateInternal(ctx, idToken, token) - if err != nil { - return nil, err - } - } - - created := time.Now() - newSession.AccessToken = token.AccessToken - newSession.RefreshToken = token.RefreshToken - newSession.CreatedAt = &created - newSession.ExpiresOn = &token.Expiry - return newSession, nil + return nil } +// CreateSessionFromToken converts Bearer IDTokens into sessions func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { idToken, err := p.Verifier.Verify(ctx, token) if err != nil { return nil, err } - newSession, err := p.createSessionStateInternal(ctx, idToken, nil) + ss, err := p.buildSessionFromClaims(idToken) if err != nil { return nil, err } - newSession.AccessToken = token - newSession.IDToken = token - newSession.RefreshToken = "" - newSession.ExpiresOn = &idToken.Expiry - - return newSession, nil -} - -func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { - - newSession := &sessions.SessionState{} - - if idToken == nil { - return newSession, nil + // Allow empty Email in Bearer case since we can't hit the ProfileURL + if ss.Email == "" { + ss.Email = ss.User } - claims, err := p.findClaimsFromIDToken(ctx, idToken, token) + ss.AccessToken = token + ss.IDToken = token + ss.RefreshToken = "" + ss.ExpiresOn = &idToken.Expiry + + return ss, nil +} + +// createSession takes an oauth2.Token and creates a SessionState from it. +// It alters behavior if called from Redeem vs Refresh +func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { + idToken, err := p.findVerifiedIDToken(ctx, token) + if err != nil { + return nil, fmt.Errorf("could not verify id_token: %v", err) + } + + // IDToken is mandatory in Redeem but optional in Refresh + if idToken == nil && !refresh { + return nil, errors.New("token response did not contain an id_token") + } + + ss, err := p.buildSessionFromClaims(idToken) + if err != nil { + return nil, err + } + + ss.AccessToken = token.AccessToken + ss.RefreshToken = token.RefreshToken + ss.IDToken = getIDToken(token) + + created := time.Now() + ss.CreatedAt = &created + ss.ExpiresOn = &token.Expiry + + return ss, nil +} + +func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { + rawIDToken := getIDToken(token) + if strings.TrimSpace(rawIDToken) != "" { + return p.Verifier.Verify(ctx, rawIDToken) + } + return nil, nil +} + +// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState +// with non-Token related fields. +func (p *OIDCProvider) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { + ss := &sessions.SessionState{} + + if idToken == nil { + return ss, nil + } + + claims, err := p.getClaims(idToken) if err != nil { return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) } - if token != nil { - newSession.IDToken = token.Extra("id_token").(string) + ss.User = claims.Subject + ss.Email = claims.Email + ss.Groups = claims.Groups + + // TODO (@NickMeves) Deprecate for dynamic claim to session mapping + if pref, ok := claims.rawClaims["preferred_username"].(string); ok { + ss.PreferredUsername = pref } - newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future - - newSession.User = claims.Subject - newSession.Groups = claims.Groups - newSession.PreferredUsername = claims.PreferredUsername - - verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail + verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail if verifyEmail && claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID) + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } - return newSession, nil + return ss, nil } -// ValidateSessionState checks that the session's IDToken is still valid -func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { - _, err := p.Verifier.Verify(ctx, s.IDToken) - return err == nil +type OIDCClaims struct { + Subject string `json:"sub"` + Email string `json:"-"` + Groups []string `json:"-"` + Verified *bool `json:"email_verified"` + + rawClaims map[string]interface{} } -func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { +// getClaims extracts IDToken claims into an OIDCClaims +func (p *OIDCProvider) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { claims := &OIDCClaims{} // Extract default claims. @@ -242,86 +300,28 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) } - userID := claims.rawClaims[p.UserIDClaim] - if userID != nil { - claims.UserID = fmt.Sprint(userID) - } - - claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims) - - // userID claim was not present or was empty in the ID Token - if claims.UserID == "" { - // BearerToken case, allow empty UserID - // ProfileURL checks below won't work since we don't have an access token - if token == nil { - claims.UserID = claims.Subject - return claims, nil - } - - profileURL := p.ProfileURL.String() - if profileURL == "" || token.AccessToken == "" { - return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim) - } - - // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo - // contents at the profileURL contains the email. - // Make a query to the userinfo endpoint, and attempt to locate the email from there. - respJSON, err := requests.New(profileURL). - WithContext(ctx). - WithHeaders(makeOIDCHeader(token.AccessToken)). - Do(). - UnmarshalJSON() - if err != nil { - return nil, err - } - - userID, err := respJSON.Get(p.UserIDClaim).String() - if err != nil { - return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim) - } - - claims.UserID = userID + email := claims.rawClaims[p.EmailClaim] + if email != nil { + claims.Email = fmt.Sprint(email) } + claims.Groups = p.extractGroups(claims.rawClaims) return claims, nil } -func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string { +func (p *OIDCProvider) extractGroups(claims map[string]interface{}) []string { groups := []string{} - - rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{}) + rawGroups, ok := claims[p.GroupsClaim].([]interface{}) if rawGroups != nil && ok { for _, rawGroup := range rawGroups { formattedGroup, err := formatGroup(rawGroup) if err != nil { - logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err) + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(rawGroup), err) continue } groups = append(groups, formattedGroup) } } - return groups - -} - -func formatGroup(rawGroup interface{}) (string, error) { - group, ok := rawGroup.(string) - if !ok { - jsonGroup, err := json.Marshal(rawGroup) - if err != nil { - return "", err - } - group = string(jsonGroup) - } - return group, nil -} - -type OIDCClaims struct { - rawClaims map[string]interface{} - UserID string - Subject string `json:"sub"` - Verified *bool `json:"email_verified"` - PreferredUsername string `json:"preferred_username"` - Groups []string `json:"-"` } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 54df1b31..0cc6ef76 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -154,7 +154,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { p := &OIDCProvider{ ProviderData: providerData, - UserIDClaim: "email", + EmailClaim: "email", } return p @@ -225,7 +225,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { }) server, provider := newTestSetup(body) - provider.UserIDClaim = "phone_number" + provider.EmailClaim = "phone_number" defer server.Close() session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") @@ -233,6 +233,256 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { assert.Equal(t, defaultIDToken.Phone, session.Email) } +func TestOIDCProvider_EnrichSession(t *testing.T) { + const ( + idToken = "Unchanged ID Token" + accessToken = "Unchanged Access Token" + refreshToken = "Unchanged Refresh Token" + ) + + testCases := map[string]struct { + ExistingSession *sessions.SessionState + EmailClaim string + GroupsClaim string + ProfileJSON map[string]interface{} + ExpectedError error + ExpectedSession *sessions.SessionState + }{ + "Already Populated": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Email": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "found@email.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Email: "found@email.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + + "Missing Email Only in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "found@email.com", + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Email: "found@email.com", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Email with Custom Claim": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "weird", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "weird": "weird@claim.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Email: "weird@claim.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Email not in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "groups": []string{"new", "thing"}, + }, + ExpectedError: errors.New("neither the id_token nor the profileURL set an email"), + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"new", "thing"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups with Custom Claim": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: nil, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "roles", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "roles": []string{"new", "thing", "roles"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"new", "thing", "roles"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups String Profile URL Response": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": "singleton", + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"singleton"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups in both Claims and Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + } + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + jsonResp, err := json.Marshal(tc.ProfileJSON) + assert.NoError(t, err) + + server, provider := newTestSetup(jsonResp) + provider.ProfileURL, err = url.Parse(server.URL) + assert.NoError(t, err) + + provider.EmailClaim = tc.EmailClaim + provider.GroupsClaim = tc.GroupsClaim + defer server.Close() + + err = provider.EnrichSession(context.Background(), tc.ExistingSession) + assert.Equal(t, tc.ExpectedError, err) + assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession) + }) + } +} + func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { idToken, _ := newSignedTestIDToken(defaultIDToken) @@ -361,7 +611,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { } } -func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { +func TestOIDCProvider_findVerifiedIDToken(t *testing.T) { server, provider := newTestSetup([]byte("")) @@ -397,31 +647,3 @@ func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, true, verifiedIDToken == nil) } - -func Test_formatGroup(t *testing.T) { - testCases := map[string]struct { - RawGroup interface{} - ExpectedFormattedGroupValue string - }{ - "String Group": { - RawGroup: "group", - ExpectedFormattedGroupValue: "group", - }, - "Map Group": { - RawGroup: map[string]string{"id": "1", "name": "Test"}, - ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}", - }, - "List Group": { - RawGroup: []string{"First", "Second"}, - ExpectedFormattedGroupValue: "[\"First\",\"Second\"]", - }, - } - - for testName, tc := range testCases { - t.Run(testName, func(t *testing.T) { - formattedGroup, err := formatGroup(tc.RawGroup) - assert.Nil(t, err) - assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup) - }) - } -} diff --git a/providers/util.go b/providers/util.go index b4b65ac6..acf20902 100644 --- a/providers/util.go +++ b/providers/util.go @@ -1,9 +1,12 @@ package providers import ( + "encoding/json" "fmt" "net/http" "net/url" + + "golang.org/x/oauth2" ) const ( @@ -55,3 +58,23 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va a.RawQuery = params.Encode() return a } + +func getIDToken(token *oauth2.Token) string { + idToken, ok := token.Extra("id_token").(string) + if !ok { + return "" + } + return idToken +} + +func formatGroup(rawGroup interface{}) (string, error) { + group, ok := rawGroup.(string) + if !ok { + jsonGroup, err := json.Marshal(rawGroup) + if err != nil { + return "", err + } + group = string(jsonGroup) + } + return group, nil +} diff --git a/providers/util_test.go b/providers/util_test.go index 798df6cb..e14ff061 100644 --- a/providers/util_test.go +++ b/providers/util_test.go @@ -5,9 +5,10 @@ import ( "testing" . "github.com/onsi/gomega" + "golang.org/x/oauth2" ) -func TestMakeAuhtorizationHeader(t *testing.T) { +func Test_makeAuthorizationHeader(t *testing.T) { testCases := []struct { name string prefix string @@ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) { }) } } + +func Test_getIDToken(t *testing.T) { + const idToken = "eyJfoobar.eyJfoobar.12345asdf" + g := NewWithT(t) + + token := &oauth2.Token{} + g.Expect(getIDToken(token)).To(Equal("")) + + extraToken := token.WithExtra(map[string]interface{}{ + "id_token": idToken, + }) + g.Expect(getIDToken(extraToken)).To(Equal(idToken)) +} + +func Test_formatGroup(t *testing.T) { + testCases := map[string]struct { + rawGroup interface{} + expected string + }{ + "String Group": { + rawGroup: "group", + expected: "group", + }, + "Numeric Group": { + rawGroup: 123, + expected: "123", + }, + "Map Group": { + rawGroup: map[string]string{"id": "1", "name": "Test"}, + expected: "{\"id\":\"1\",\"name\":\"Test\"}", + }, + "List Group": { + rawGroup: []string{"First", "Second"}, + expected: "[\"First\",\"Second\"]", + }, + } + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + formattedGroup, err := formatGroup(tc.rawGroup) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(formattedGroup).To(Equal(tc.expected)) + }) + } +}