Cleanup internalSession params & handle profileURL Bearer case better
`findClaimsFromIDToken` would always have a `nil` access token and not be able to hit the userinfo endpoint in Bearer case. If access token is nil, default to legacy `session.Email = claim.Subject` that all JWT bearers used to have, even if a valid profileURL is present.
This commit is contained in:
		
							parent
							
								
									dcc75410a8
								
							
						
					
					
						commit
						0645e19c24
					
				|  | @ -157,7 +157,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | ||||||
| 		newSession = &sessions.SessionState{} | 		newSession = &sessions.SessionState{} | ||||||
| 	} else { | 	} else { | ||||||
| 		var err error | 		var err error | ||||||
| 		newSession, err = p.createSessionStateInternal(ctx, token.Extra("id_token").(string), idToken, token, false) | 		newSession, err = p.createSessionStateInternal(ctx, idToken, token) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | @ -172,7 +172,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { | func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { | ||||||
| 	newSession, err := p.createSessionStateInternal(ctx, rawIDToken, idToken, nil, true) | 	newSession, err := p.createSessionStateInternal(ctx, idToken, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -185,24 +185,22 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, ra | ||||||
| 	return newSession, nil | 	return newSession, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, rawIDToken string, idToken *oidc.IDToken, token *oauth2.Token, bearer bool) (*sessions.SessionState, error) { | func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { | ||||||
| 
 | 
 | ||||||
| 	newSession := &sessions.SessionState{} | 	newSession := &sessions.SessionState{} | ||||||
| 
 | 
 | ||||||
| 	if idToken == nil { | 	if idToken == nil { | ||||||
| 		return newSession, nil | 		return newSession, nil | ||||||
| 	} | 	} | ||||||
| 	accessToken := "" |  | ||||||
| 	if token != nil { |  | ||||||
| 		accessToken = token.AccessToken |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	claims, err := p.findClaimsFromIDToken(ctx, idToken, accessToken, p.ProfileURL.String(), bearer) | 	claims, err := p.findClaimsFromIDToken(ctx, idToken, token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) | 		return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	newSession.IDToken = rawIDToken | 	if token != nil { | ||||||
|  | 		newSession.IDToken = token.Extra("id_token").(string) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
 | 	newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
 | ||||||
| 
 | 
 | ||||||
|  | @ -230,7 +228,7 @@ func getOIDCHeader(accessToken string) http.Header { | ||||||
| 	return header | 	return header | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, accessToken string, profileURL string, bearer bool) (*OIDCClaims, error) { | func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { | ||||||
| 	claims := &OIDCClaims{} | 	claims := &OIDCClaims{} | ||||||
| 	// Extract default claims.
 | 	// Extract default claims.
 | ||||||
| 	if err := idToken.Claims(&claims); err != nil { | 	if err := idToken.Claims(&claims); err != nil { | ||||||
|  | @ -248,11 +246,15 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | ||||||
| 
 | 
 | ||||||
| 	// userID claim was not present or was empty in the ID Token
 | 	// userID claim was not present or was empty in the ID Token
 | ||||||
| 	if claims.UserID == "" { | 	if claims.UserID == "" { | ||||||
| 		if profileURL == "" { | 		// BearerToken case, allow empty UserID
 | ||||||
| 			if bearer { | 		// ProfileURL checks below won't work since we don't have an access token
 | ||||||
|  | 		if token == nil { | ||||||
| 			claims.UserID = claims.Subject | 			claims.UserID = claims.Subject | ||||||
| 			return claims, nil | 			return claims, nil | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		profileURL := p.ProfileURL.String() | ||||||
|  | 		if profileURL == "" || token.AccessToken == "" { | ||||||
| 			return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim) | 			return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -261,7 +263,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | ||||||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | ||||||
| 		respJSON, err := requests.New(profileURL). | 		respJSON, err := requests.New(profileURL). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
| 			WithHeaders(getOIDCHeader(accessToken)). | 			WithHeaders(getOIDCHeader(token.AccessToken)). | ||||||
| 			Do(). | 			Do(). | ||||||
| 			UnmarshalJSON() | 			UnmarshalJSON() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  |  | ||||||
|  | @ -274,23 +274,18 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	testCases := map[string]struct { | 	testCases := map[string]struct { | ||||||
| 		IDToken       idTokenClaims | 		IDToken       idTokenClaims | ||||||
| 		ProfileURL    bool | 		ExpectedUser  string | ||||||
| 		ExpectedEmail string | 		ExpectedEmail string | ||||||
| 	}{ | 	}{ | ||||||
| 		"Default IDToken": { | 		"Default IDToken": { | ||||||
| 			IDToken:       defaultIDToken, | 			IDToken:       defaultIDToken, | ||||||
| 			ProfileURL:    true, | 			ExpectedUser:  defaultIDToken.Subject, | ||||||
| 			ExpectedEmail: profileURLEmail, | 			ExpectedEmail: defaultIDToken.Email, | ||||||
| 		}, | 		}, | ||||||
| 		"Minimal IDToken with no OIDC Profile URL": { | 		"Minimal IDToken with no email claim": { | ||||||
| 			IDToken:       minimalIDToken, | 			IDToken:       minimalIDToken, | ||||||
| 			ProfileURL:    false, | 			ExpectedUser:  minimalIDToken.Subject, | ||||||
| 			ExpectedEmail: "", | 			ExpectedEmail: minimalIDToken.Subject, | ||||||
| 		}, |  | ||||||
| 		"Minimal IDToken with OIDC Profile URL": { |  | ||||||
| 			IDToken:       minimalIDToken, |  | ||||||
| 			ProfileURL:    true, |  | ||||||
| 			ExpectedEmail: profileURLEmail, |  | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	for testName, tc := range testCases { | 	for testName, tc := range testCases { | ||||||
|  | @ -298,9 +293,6 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { | ||||||
| 			jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail)) | 			jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail)) | ||||||
| 			server, provider := newTestSetup(jsonResp) | 			server, provider := newTestSetup(jsonResp) | ||||||
| 			defer server.Close() | 			defer server.Close() | ||||||
| 			if !tc.ProfileURL { |  | ||||||
| 				provider.ProfileURL = &url.URL{} |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
|  | @ -315,13 +307,8 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { | ||||||
| 			ss, err := provider.CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken) | 			ss, err := provider.CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken) | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			if tc.ExpectedEmail != "" { | 			assert.Equal(t, tc.ExpectedUser, ss.User) | ||||||
| 			assert.Equal(t, tc.ExpectedEmail, ss.Email) | 			assert.Equal(t, tc.ExpectedEmail, ss.Email) | ||||||
| 				assert.NotEqual(t, ss.Email, ss.User) |  | ||||||
| 			} else { |  | ||||||
| 				assert.Equal(t, tc.IDToken.Subject, ss.Email) |  | ||||||
| 				assert.Equal(t, ss.Email, ss.User) |  | ||||||
| 			} |  | ||||||
| 			assert.Equal(t, rawIDToken, ss.IDToken) | 			assert.Equal(t, rawIDToken, ss.IDToken) | ||||||
| 			assert.Equal(t, rawIDToken, ss.AccessToken) | 			assert.Equal(t, rawIDToken, ss.AccessToken) | ||||||
| 			assert.Equal(t, "", ss.RefreshToken) | 			assert.Equal(t, "", ss.RefreshToken) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue