Integrate claim extractor into providers
This commit is contained in:
		
							parent
							
								
									537e596904
								
							
						
					
					
						commit
						967051314e
					
				|  | @ -103,16 +103,17 @@ func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionSt | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { | func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { | ||||||
| 	idToken, err := p.Verifier.Verify(ctx, s.IDToken) | 	claims, err := p.getClaimExtractor(s.IDToken, s.AccessToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return fmt.Errorf("could not extract claims: %v", err) | ||||||
| 	} | 	} | ||||||
| 	claims, err := p.getClaims(idToken) | 
 | ||||||
|  | 	upn, found, err := claims.GetClaim(adfsUPNClaim) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("couldn't extract claims from id_token (%v)", err) | 		return fmt.Errorf("could not extract %s claim: %v", adfsUPNClaim, err) | ||||||
| 	} | 	} | ||||||
| 	upn := claims.raw[adfsUPNClaim] | 
 | ||||||
| 	if upn != nil { | 	if found && fmt.Sprint(upn) != "" { | ||||||
| 		s.Email = fmt.Sprint(upn) | 		s.Email = fmt.Sprint(upn) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
|  |  | ||||||
|  | @ -150,9 +150,7 @@ var _ = Describe("ADFS Provider Tests", func() { | ||||||
| 	Context("with valid token", func() { | 	Context("with valid token", func() { | ||||||
| 		It("should not throw an error", func() { | 		It("should not throw an error", func() { | ||||||
| 			rawIDToken, _ := newSignedTestIDToken(defaultIDToken) | 			rawIDToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
| 			idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) | 			session, err := p.buildSessionFromClaims(rawIDToken, "") | ||||||
| 			Expect(err).To(BeNil()) |  | ||||||
| 			session, err := p.buildSessionFromClaims(idToken) |  | ||||||
| 			Expect(err).To(BeNil()) | 			Expect(err).To(BeNil()) | ||||||
| 			session.IDToken = rawIDToken | 			session.IDToken = rawIDToken | ||||||
| 			err = p.EnrichSession(context.Background(), session) | 			err = p.EnrichSession(context.Background(), session) | ||||||
|  |  | ||||||
|  | @ -15,9 +15,20 @@ func CreateAuthorizedSession() *sessions.SessionState { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func IsAuthorizedInHeader(reqHeader http.Header) bool { | func IsAuthorizedInHeader(reqHeader http.Header) bool { | ||||||
| 	return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", authorizedAccessToken) | 	return IsAuthorizedInHeaderWithToken(reqHeader, authorizedAccessToken) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func IsAuthorizedInHeaderWithToken(reqHeader http.Header, token string) bool { | ||||||
|  | 	return reqHeader.Get("Authorization") == fmt.Sprintf("Bearer %s", token) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func IsAuthorizedInURL(reqURL *url.URL) bool { | func IsAuthorizedInURL(reqURL *url.URL) bool { | ||||||
| 	return reqURL.Query().Get("access_token") == authorizedAccessToken | 	return reqURL.Query().Get("access_token") == authorizedAccessToken | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func isAuthorizedRefreshInURLWithToken(reqURL *url.URL, token string) bool { | ||||||
|  | 	if token == "" { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return reqURL.Query().Get("refresh_token") == token | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -78,6 +78,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider { | ||||||
| 	if p.ValidateURL == nil || p.ValidateURL.String() == "" { | 	if p.ValidateURL == nil || p.ValidateURL.String() == "" { | ||||||
| 		p.ValidateURL = p.ProfileURL | 		p.ValidateURL = p.ProfileURL | ||||||
| 	} | 	} | ||||||
|  | 	p.getAuthorizationHeaderFunc = makeAzureHeader | ||||||
| 
 | 
 | ||||||
| 	return &AzureProvider{ | 	return &AzureProvider{ | ||||||
| 		ProviderData: p, | 		ProviderData: p, | ||||||
|  | @ -150,7 +151,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* | ||||||
| 	session.CreatedAtNow() | 	session.CreatedAtNow() | ||||||
| 	session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | 	session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | ||||||
| 
 | 
 | ||||||
| 	email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) | 	email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken, session.AccessToken) | ||||||
| 
 | 
 | ||||||
| 	// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
 | 	// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
 | ||||||
| 	// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
 | 	// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
 | ||||||
|  | @ -163,7 +164,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if session.Email == "" { | 	if session.Email == "" { | ||||||
| 		email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken) | 		email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken, session.AccessToken) | ||||||
| 		if err == nil && email != "" { | 		if err == nil && email != "" { | ||||||
| 			session.Email = email | 			session.Email = email | ||||||
| 		} else { | 		} else { | ||||||
|  | @ -215,16 +216,16 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err | ||||||
| 
 | 
 | ||||||
| // verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token
 | // verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token
 | ||||||
| // when oidc verifier is configured
 | // when oidc verifier is configured
 | ||||||
| func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) { | func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, rawIDToken string, accessToken string) (string, error) { | ||||||
| 	email := "" | 	email := "" | ||||||
| 
 | 
 | ||||||
| 	if token != "" && p.Verifier != nil { | 	if rawIDToken != "" && p.Verifier != nil { | ||||||
| 		token, err := p.Verifier.Verify(ctx, token) | 		_, err := p.Verifier.Verify(ctx, rawIDToken) | ||||||
| 		// due to issues mentioned above, id_token may not be signed by AAD
 | 		// due to issues mentioned above, id_token may not be signed by AAD
 | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			claims, err := p.getClaims(token) | 			s, err := p.buildSessionFromClaims(rawIDToken, accessToken) | ||||||
| 			if err == nil { | 			if err == nil { | ||||||
| 				email = claims.Email | 				email = s.Email | ||||||
| 			} else { | 			} else { | ||||||
| 				logger.Printf("unable to get claims from token: %v", err) | 				logger.Printf("unable to get claims from token: %v", err) | ||||||
| 			} | 			} | ||||||
|  | @ -287,7 +288,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess | ||||||
| 	s.CreatedAtNow() | 	s.CreatedAtNow() | ||||||
| 	s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | 	s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | ||||||
| 
 | 
 | ||||||
| 	email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) | 	email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken, s.AccessToken) | ||||||
| 
 | 
 | ||||||
| 	// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
 | 	// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
 | ||||||
| 	// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
 | 	// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
 | ||||||
|  | @ -300,7 +301,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if s.Email == "" { | 	if s.Email == "" { | ||||||
| 		email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken) | 		email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken, s.AccessToken) | ||||||
| 		if err == nil && email != "" { | 		if err == nil && email != "" { | ||||||
| 			s.Email = email | 			s.Email = email | ||||||
| 		} else { | 		} else { | ||||||
|  |  | ||||||
|  | @ -13,9 +13,8 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/coreos/go-oidc/v3/oidc" | ||||||
| 	"github.com/golang-jwt/jwt" | 	"github.com/golang-jwt/jwt" | ||||||
| 
 |  | ||||||
| 	oidc "github.com/coreos/go-oidc/v3/oidc" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" | 	internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" | ||||||
| 
 | 
 | ||||||
|  | @ -145,11 +144,11 @@ func TestAzureSetTenant(t *testing.T) { | ||||||
| 	assert.Equal(t, "openid", p.Data().Scope) | 	assert.Equal(t, "openid", p.Data().Scope) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func testAzureBackend(payload string) *httptest.Server { | func testAzureBackend(payload string, accessToken, refreshToken string) *httptest.Server { | ||||||
| 	return testAzureBackendWithError(payload, false) | 	return testAzureBackendWithError(payload, accessToken, refreshToken, false) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func testAzureBackendWithError(payload string, injectError bool) *httptest.Server { | func testAzureBackendWithError(payload string, accessToken, refreshToken string, injectError bool) *httptest.Server { | ||||||
| 	path := "/v1.0/me" | 	path := "/v1.0/me" | ||||||
| 
 | 
 | ||||||
| 	return httptest.NewServer(http.HandlerFunc( | 	return httptest.NewServer(http.HandlerFunc( | ||||||
|  | @ -163,7 +162,8 @@ func testAzureBackendWithError(payload string, injectError bool) *httptest.Serve | ||||||
| 					w.WriteHeader(200) | 					w.WriteHeader(200) | ||||||
| 				} | 				} | ||||||
| 				w.Write([]byte(payload)) | 				w.Write([]byte(payload)) | ||||||
| 			} else if !IsAuthorizedInHeader(r.Header) { | 			} else if !IsAuthorizedInHeaderWithToken(r.Header, accessToken) && | ||||||
|  | 				!isAuthorizedRefreshInURLWithToken(r.URL, refreshToken) { | ||||||
| 				w.WriteHeader(403) | 				w.WriteHeader(403) | ||||||
| 			} else { | 			} else { | ||||||
| 				w.WriteHeader(200) | 				w.WriteHeader(200) | ||||||
|  | @ -224,7 +224,7 @@ func TestAzureProviderEnrichSession(t *testing.T) { | ||||||
| 				host string | 				host string | ||||||
| 			) | 			) | ||||||
| 			if testCase.PayloadFromAzureBackend != "" { | 			if testCase.PayloadFromAzureBackend != "" { | ||||||
| 				b = testAzureBackend(testCase.PayloadFromAzureBackend) | 				b = testAzureBackend(testCase.PayloadFromAzureBackend, authorizedAccessToken, "") | ||||||
| 				defer b.Close() | 				defer b.Close() | ||||||
| 
 | 
 | ||||||
| 				bURL, _ := url.Parse(b.URL) | 				bURL, _ := url.Parse(b.URL) | ||||||
|  | @ -319,7 +319,7 @@ func TestAzureProviderRedeem(t *testing.T) { | ||||||
| 			payloadBytes, err := json.Marshal(payload) | 			payloadBytes, err := json.Marshal(payload) | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError) | 			b := testAzureBackendWithError(string(payloadBytes), accessTokenString, testCase.RefreshToken, testCase.InjectRedeemURLError) | ||||||
| 			defer b.Close() | 			defer b.Close() | ||||||
| 
 | 
 | ||||||
| 			bURL, _ := url.Parse(b.URL) | 			bURL, _ := url.Parse(b.URL) | ||||||
|  | @ -353,35 +353,44 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| func TestAzureProviderRefresh(t *testing.T) { | func TestAzureProviderRefresh(t *testing.T) { | ||||||
| 	email := "foo@example.com" | 	email := "foo@example.com" | ||||||
|  | 	subject := "foo" | ||||||
| 	idToken := idTokenClaims{ | 	idToken := idTokenClaims{ | ||||||
| 		StandardClaims: jwt.StandardClaims{Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532"}, | 		Email: email, | ||||||
| 		Email:          email} | 		StandardClaims: jwt.StandardClaims{ | ||||||
|  | 			Audience: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532", | ||||||
|  | 			Subject:  subject, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
| 	idTokenString, err := newSignedTestIDToken(idToken) | 	idTokenString, err := newSignedTestIDToken(idToken) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
| 	timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") | 	timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	newAccessToken := "new_some_access_token" | ||||||
| 	payload := azureOAuthPayload{ | 	payload := azureOAuthPayload{ | ||||||
| 		IDToken:      idTokenString, | 		IDToken:      idTokenString, | ||||||
| 		RefreshToken: "new_some_refresh_token", | 		RefreshToken: "new_some_refresh_token", | ||||||
| 		AccessToken:  "new_some_access_token", | 		AccessToken:  newAccessToken, | ||||||
| 		ExpiresOn:    timestamp.Unix(), | 		ExpiresOn:    timestamp.Unix(), | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	payloadBytes, err := json.Marshal(payload) | 	payloadBytes, err := json.Marshal(payload) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 	b := testAzureBackend(string(payloadBytes)) | 
 | ||||||
|  | 	refreshToken := "some_refresh_token" | ||||||
|  | 	b := testAzureBackend(string(payloadBytes), newAccessToken, refreshToken) | ||||||
| 	defer b.Close() | 	defer b.Close() | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testAzureProvider(bURL.Host) | 	p := testAzureProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	expires := time.Now().Add(time.Duration(-1) * time.Hour) | 	expires := time.Now().Add(time.Duration(-1) * time.Hour) | ||||||
| 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: refreshToken, IDToken: "some_id_token", ExpiresOn: &expires} | ||||||
| 
 | 
 | ||||||
| 	refreshed, err := p.RefreshSession(context.Background(), session) | 	refreshed, err := p.RefreshSession(context.Background(), session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.True(t, refreshed) | 	assert.True(t, refreshed) | ||||||
| 	assert.NotEqual(t, session, nil) | 	assert.NotEqual(t, session, nil) | ||||||
| 	assert.Equal(t, "new_some_access_token", session.AccessToken) | 	assert.Equal(t, newAccessToken, session.AccessToken) | ||||||
| 	assert.Equal(t, "new_some_refresh_token", session.RefreshToken) | 	assert.Equal(t, "new_some_refresh_token", session.RefreshToken) | ||||||
| 	assert.Equal(t, idTokenString, session.IDToken) | 	assert.Equal(t, idTokenString, session.IDToken) | ||||||
| 	assert.Equal(t, email, session.Email) | 	assert.Equal(t, email, session.Email) | ||||||
|  |  | ||||||
|  | @ -57,6 +57,8 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { | ||||||
| 		validateURL: digitalOceanDefaultProfileURL, | 		validateURL: digitalOceanDefaultProfileURL, | ||||||
| 		scope:       digitalOceanDefaultScope, | 		scope:       digitalOceanDefaultScope, | ||||||
| 	}) | 	}) | ||||||
|  | 	p.getAuthorizationHeaderFunc = makeOIDCHeader | ||||||
|  | 
 | ||||||
| 	return &DigitalOceanProvider{ProviderData: p} | 	return &DigitalOceanProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -58,6 +58,7 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { | ||||||
| 		validateURL: facebookDefaultProfileURL, | 		validateURL: facebookDefaultProfileURL, | ||||||
| 		scope:       facebookDefaultScope, | 		scope:       facebookDefaultScope, | ||||||
| 	}) | 	}) | ||||||
|  | 	p.getAuthorizationHeaderFunc = makeOIDCHeader | ||||||
| 	return &FacebookProvider{ProviderData: p} | 	return &FacebookProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -65,6 +65,8 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { | ||||||
| 		validateURL: linkedinDefaultValidateURL, | 		validateURL: linkedinDefaultValidateURL, | ||||||
| 		scope:       linkedinDefaultScope, | 		scope:       linkedinDefaultScope, | ||||||
| 	}) | 	}) | ||||||
|  | 	p.getAuthorizationHeaderFunc = makeLinkedInHeader | ||||||
|  | 
 | ||||||
| 	return &LinkedInProvider{ProviderData: p} | 	return &LinkedInProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,13 +1,5 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"fmt" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // NextcloudProvider represents an Nextcloud based Identity Provider
 | // NextcloudProvider represents an Nextcloud based Identity Provider
 | ||||||
| type NextcloudProvider struct { | type NextcloudProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
|  | @ -20,20 +12,11 @@ const nextCloudProviderName = "Nextcloud" | ||||||
| // NewNextcloudProvider initiates a new NextcloudProvider
 | // NewNextcloudProvider initiates a new NextcloudProvider
 | ||||||
| func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { | func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { | ||||||
| 	p.ProviderName = nextCloudProviderName | 	p.ProviderName = nextCloudProviderName | ||||||
|  | 	p.getAuthorizationHeaderFunc = makeOIDCHeader | ||||||
|  | 	if p.EmailClaim == OIDCEmailClaim { | ||||||
|  | 		// This implies the email claim has not been overridden, we should set a default
 | ||||||
|  | 		// for this provider
 | ||||||
|  | 		p.EmailClaim = "ocs.data.email" | ||||||
|  | 	} | ||||||
| 	return &NextcloudProvider{ProviderData: p} | 	return &NextcloudProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // GetEmailAddress returns the Account email address
 |  | ||||||
| func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { |  | ||||||
| 	json, err := requests.New(p.ValidateURL.String()). |  | ||||||
| 		WithContext(ctx). |  | ||||||
| 		WithHeaders(makeOIDCHeader(s.AccessToken)). |  | ||||||
| 		Do(). |  | ||||||
| 		UnmarshalJSON() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", fmt.Errorf("error making request: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	email, err := json.Get("ocs").Get("data").Get("email").String() |  | ||||||
| 	return email, err |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -1,18 +1,13 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const formatJSON = "format=json" | const formatJSON = "format=json" | ||||||
| const userPath = "/ocs/v2.php/cloud/user" |  | ||||||
| 
 | 
 | ||||||
| func testNextcloudProvider(hostname string) *NextcloudProvider { | func testNextcloudProvider(hostname string) *NextcloudProvider { | ||||||
| 	p := NewNextcloudProvider( | 	p := NewNextcloudProvider( | ||||||
|  | @ -32,23 +27,6 @@ func testNextcloudProvider(hostname string) *NextcloudProvider { | ||||||
| 	return p | 	return p | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func testNextcloudBackend(payload string) *httptest.Server { |  | ||||||
| 	path := userPath |  | ||||||
| 	query := formatJSON |  | ||||||
| 
 |  | ||||||
| 	return httptest.NewServer(http.HandlerFunc( |  | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 			if r.URL.Path != path || r.URL.RawQuery != query { |  | ||||||
| 				w.WriteHeader(404) |  | ||||||
| 			} else if !IsAuthorizedInHeader(r.Header) { |  | ||||||
| 				w.WriteHeader(403) |  | ||||||
| 			} else { |  | ||||||
| 				w.WriteHeader(200) |  | ||||||
| 				w.Write([]byte(payload)) |  | ||||||
| 			} |  | ||||||
| 		})) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestNextcloudProviderDefaults(t *testing.T) { | func TestNextcloudProviderDefaults(t *testing.T) { | ||||||
| 	p := testNextcloudProvider("") | 	p := testNextcloudProvider("") | ||||||
| 	assert.NotEqual(t, nil, p) | 	assert.NotEqual(t, nil, p) | ||||||
|  | @ -87,53 +65,3 @@ func TestNextcloudProviderOverrides(t *testing.T) { | ||||||
| 	assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON, | 	assert.Equal(t, "https://example.com/test/ocs/v2.php/cloud/user?"+formatJSON, | ||||||
| 		p.Data().ValidateURL.String()) | 		p.Data().ValidateURL.String()) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func TestNextcloudProviderGetEmailAddress(t *testing.T) { |  | ||||||
| 	b := testNextcloudBackend("{\"ocs\": {\"data\": { \"email\": \"michael.bland@gsa.gov\"}}}") |  | ||||||
| 	defer b.Close() |  | ||||||
| 
 |  | ||||||
| 	bURL, _ := url.Parse(b.URL) |  | ||||||
| 	p := testNextcloudProvider(bURL.Host) |  | ||||||
| 	p.ValidateURL.Path = userPath |  | ||||||
| 	p.ValidateURL.RawQuery = formatJSON |  | ||||||
| 
 |  | ||||||
| 	session := CreateAuthorizedSession() |  | ||||||
| 	email, err := p.GetEmailAddress(context.Background(), session) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Note that trying to trigger the "failed building request" case is not
 |  | ||||||
| // practical, since the only way it can fail is if the URL fails to parse.
 |  | ||||||
| func TestNextcloudProviderGetEmailAddressFailedRequest(t *testing.T) { |  | ||||||
| 	b := testNextcloudBackend("unused payload") |  | ||||||
| 	defer b.Close() |  | ||||||
| 
 |  | ||||||
| 	bURL, _ := url.Parse(b.URL) |  | ||||||
| 	p := testNextcloudProvider(bURL.Host) |  | ||||||
| 	p.ValidateURL.Path = userPath |  | ||||||
| 	p.ValidateURL.RawQuery = formatJSON |  | ||||||
| 
 |  | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 |  | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 |  | ||||||
| 	// JSON to fail.
 |  | ||||||
| 	session := &sessions.SessionState{AccessToken: "unexpected_access_token"} |  | ||||||
| 	email, err := p.GetEmailAddress(context.Background(), session) |  | ||||||
| 	assert.NotEqual(t, nil, err) |  | ||||||
| 	assert.Equal(t, "", email) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestNextcloudProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { |  | ||||||
| 	b := testNextcloudBackend("{\"foo\": \"bar\"}") |  | ||||||
| 	defer b.Close() |  | ||||||
| 
 |  | ||||||
| 	bURL, _ := url.Parse(b.URL) |  | ||||||
| 	p := testNextcloudProvider(bURL.Host) |  | ||||||
| 	p.ValidateURL.Path = userPath |  | ||||||
| 	p.ValidateURL.RawQuery = formatJSON |  | ||||||
| 
 |  | ||||||
| 	session := CreateAuthorizedSession() |  | ||||||
| 	email, err := p.GetEmailAddress(context.Background(), session) |  | ||||||
| 	assert.NotEqual(t, nil, err) |  | ||||||
| 	assert.Equal(t, "", email) |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -5,12 +5,10 @@ import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"reflect" |  | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" |  | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -24,6 +22,8 @@ type OIDCProvider struct { | ||||||
| // NewOIDCProvider initiates a new OIDCProvider
 | // NewOIDCProvider initiates a new OIDCProvider
 | ||||||
| func NewOIDCProvider(p *ProviderData) *OIDCProvider { | func NewOIDCProvider(p *ProviderData) *OIDCProvider { | ||||||
| 	p.ProviderName = "OpenID Connect" | 	p.ProviderName = "OpenID Connect" | ||||||
|  | 	p.getAuthorizationHeaderFunc = makeOIDCHeader | ||||||
|  | 
 | ||||||
| 	return &OIDCProvider{ | 	return &OIDCProvider{ | ||||||
| 		ProviderData: p, | 		ProviderData: p, | ||||||
| 		SkipNonce:    true, | 		SkipNonce:    true, | ||||||
|  | @ -68,21 +68,6 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s | ||||||
| // EnrichSession is called after Redeem to allow providers to enrich session fields
 | // EnrichSession is called after Redeem to allow providers to enrich session fields
 | ||||||
| // such as User, Email, Groups with provider specific API calls.
 | // such as User, Email, Groups with provider specific API calls.
 | ||||||
| func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | ||||||
| 	if p.ProfileURL.String() == "" { |  | ||||||
| 		if s.Email == "" { |  | ||||||
| 			return errors.New("id_token did not contain an email and profileURL is not defined") |  | ||||||
| 		} |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Try to get missing emails or groups from a profileURL
 |  | ||||||
| 	if s.Email == "" || s.Groups == nil { |  | ||||||
| 		err := p.enrichFromProfileURL(ctx, s) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Errorf("Warning: Profile URL request failed: %v", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// If a mandatory email wasn't set, error at this point.
 | 	// If a mandatory email wasn't set, error at this point.
 | ||||||
| 	if s.Email == "" { | 	if s.Email == "" { | ||||||
| 		return errors.New("neither the id_token nor the profileURL set an email") | 		return errors.New("neither the id_token nor the profileURL set an email") | ||||||
|  | @ -90,42 +75,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
 |  | ||||||
| // an OIDC profile URL
 |  | ||||||
| func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error { |  | ||||||
| 	respJSON, err := requests.New(p.ProfileURL.String()). |  | ||||||
| 		WithContext(ctx). |  | ||||||
| 		WithHeaders(makeOIDCHeader(s.AccessToken)). |  | ||||||
| 		Do(). |  | ||||||
| 		UnmarshalJSON() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	email, err := respJSON.Get(p.EmailClaim).String() |  | ||||||
| 	if err == nil && s.Email == "" { |  | ||||||
| 		s.Email = email |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if len(s.Groups) > 0 { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	for _, group := range coerceArray(respJSON, p.GroupsClaim) { |  | ||||||
| 		formatted, err := formatGroup(group) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Errorf("Warning: unable to format group of type %s with error %s", |  | ||||||
| 				reflect.TypeOf(group), err) |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		s.Groups = append(s.Groups, formatted) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ValidateSession checks that the session's IDToken is still valid
 | // ValidateSession checks that the session's IDToken is still valid
 | ||||||
| func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	idToken, err := p.Verifier.Verify(ctx, s.IDToken) | 	_, err := p.Verifier.Verify(ctx, s.IDToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("id_token verification failed: %v", err) | 		logger.Errorf("id_token verification failed: %v", err) | ||||||
| 		return false | 		return false | ||||||
|  | @ -134,7 +86,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS | ||||||
| 	if p.SkipNonce { | 	if p.SkipNonce { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	err = p.checkNonce(s, idToken) | 	err = p.checkNonce(s) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("nonce verification failed: %v", err) | 		logger.Errorf("nonce verification failed: %v", err) | ||||||
| 		return false | 		return false | ||||||
|  | @ -212,7 +164,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ss, err := p.buildSessionFromClaims(idToken) | 	ss, err := p.buildSessionFromClaims(token, "") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -235,7 +187,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) | ||||||
| // createSession takes an oauth2.Token and creates a SessionState from it.
 | // createSession takes an oauth2.Token and creates a SessionState from it.
 | ||||||
| // It alters behavior if called from Redeem vs Refresh
 | // It alters behavior if called from Redeem vs Refresh
 | ||||||
| func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { | func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { | ||||||
| 	idToken, err := p.verifyIDToken(ctx, token) | 	_, err := p.verifyIDToken(ctx, token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		switch err { | 		switch err { | ||||||
| 		case ErrMissingIDToken: | 		case ErrMissingIDToken: | ||||||
|  | @ -248,14 +200,15 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ss, err := p.buildSessionFromClaims(idToken) | 	rawIDToken := getIDToken(token) | ||||||
|  | 	ss, err := p.buildSessionFromClaims(rawIDToken, token.AccessToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ss.AccessToken = token.AccessToken | 	ss.AccessToken = token.AccessToken | ||||||
| 	ss.RefreshToken = token.RefreshToken | 	ss.RefreshToken = token.RefreshToken | ||||||
| 	ss.IDToken = getIDToken(token) | 	ss.IDToken = rawIDToken | ||||||
| 
 | 
 | ||||||
| 	ss.CreatedAtNow() | 	ss.CreatedAtNow() | ||||||
| 	ss.SetExpiresOn(token.Expiry) | 	ss.SetExpiresOn(token.Expiry) | ||||||
|  |  | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | @ -54,6 +53,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { | ||||||
| 		Scope:       "openid profile offline_access", | 		Scope:       "openid profile offline_access", | ||||||
| 		EmailClaim:  "email", | 		EmailClaim:  "email", | ||||||
| 		GroupsClaim: "groups", | 		GroupsClaim: "groups", | ||||||
|  | 		UserClaim:   "sub", | ||||||
| 		Verifier: internaloidc.NewVerifier(oidc.NewVerifier( | 		Verifier: internaloidc.NewVerifier(oidc.NewVerifier( | ||||||
| 			oidcIssuer, | 			oidcIssuer, | ||||||
| 			mockJWKS{}, | 			mockJWKS{}, | ||||||
|  | @ -142,333 +142,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { | ||||||
| 	assert.Equal(t, defaultIDToken.Phone, session.Email) | 	assert.Equal(t, defaultIDToken.Phone, session.Email) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestOIDCProvider_EnrichSession(t *testing.T) { |  | ||||||
| 	testCases := map[string]struct { |  | ||||||
| 		ExistingSession *sessions.SessionState |  | ||||||
| 		EmailClaim      string |  | ||||||
| 		GroupsClaim     string |  | ||||||
| 		ProfileJSON     map[string]interface{} |  | ||||||
| 		ExpectedError   error |  | ||||||
| 		ExpectedSession *sessions.SessionState |  | ||||||
| 	}{ |  | ||||||
| 		"Already Populated": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email":  "new@thing.com", |  | ||||||
| 				"groups": []string{"new", "thing"}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Email": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email":  "found@email.com", |  | ||||||
| 				"groups": []string{"new", "thing"}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				Email:        "found@email.com", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 
 |  | ||||||
| 		"Missing Email Only in Profile URL": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email": "found@email.com", |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				Email:        "found@email.com", |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Email with Custom Claim": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "weird", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"weird":  "weird@claim.com", |  | ||||||
| 				"groups": []string{"new", "thing"}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				Email:        "weird@claim.com", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Email not in Profile URL": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"groups": []string{"new", "thing"}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: errors.New("neither the id_token nor the profileURL set an email"), |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "missing.email", |  | ||||||
| 				Groups:       []string{"already", "populated"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Groups": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       nil, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email":  "new@thing.com", |  | ||||||
| 				"groups": []string{"new", "thing"}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{"new", "thing"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Groups with Complex Groups in Profile URL": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       nil, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email": "new@thing.com", |  | ||||||
| 				"groups": []map[string]interface{}{ |  | ||||||
| 					{ |  | ||||||
| 						"groupId": "Admin Group Id", |  | ||||||
| 						"roles":   []string{"Admin"}, |  | ||||||
| 					}, |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Groups with Singleton Complex Group in Profile URL": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       nil, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email": "new@thing.com", |  | ||||||
| 				"groups": map[string]interface{}{ |  | ||||||
| 					"groupId": "Admin Group Id", |  | ||||||
| 					"roles":   []string{"Admin"}, |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Empty Groups Claims": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email":  "new@thing.com", |  | ||||||
| 				"groups": []string{"new", "thing"}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Groups with Custom Claim": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       nil, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "roles", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email": "new@thing.com", |  | ||||||
| 				"roles": []string{"new", "thing", "roles"}, |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{"new", "thing", "roles"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Groups String Profile URL Response": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       nil, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email":  "new@thing.com", |  | ||||||
| 				"groups": "singleton", |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				Groups:       []string{"singleton"}, |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Groups in both Claims and Profile URL": { |  | ||||||
| 			ExistingSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 			EmailClaim:  "email", |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ProfileJSON: map[string]interface{}{ |  | ||||||
| 				"email": "new@thing.com", |  | ||||||
| 			}, |  | ||||||
| 			ExpectedError: nil, |  | ||||||
| 			ExpectedSession: &sessions.SessionState{ |  | ||||||
| 				User:         "already", |  | ||||||
| 				Email:        "already@populated.com", |  | ||||||
| 				IDToken:      idToken, |  | ||||||
| 				AccessToken:  accessToken, |  | ||||||
| 				RefreshToken: refreshToken, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	for testName, tc := range testCases { |  | ||||||
| 		t.Run(testName, func(t *testing.T) { |  | ||||||
| 			jsonResp, err := json.Marshal(tc.ProfileJSON) |  | ||||||
| 			assert.NoError(t, err) |  | ||||||
| 
 |  | ||||||
| 			server, provider := newTestOIDCSetup(jsonResp) |  | ||||||
| 			provider.ProfileURL, err = url.Parse(server.URL) |  | ||||||
| 			assert.NoError(t, err) |  | ||||||
| 
 |  | ||||||
| 			provider.EmailClaim = tc.EmailClaim |  | ||||||
| 			provider.GroupsClaim = tc.GroupsClaim |  | ||||||
| 			defer server.Close() |  | ||||||
| 
 |  | ||||||
| 			err = provider.EnrichSession(context.Background(), tc.ExistingSession) |  | ||||||
| 			assert.Equal(t, tc.ExpectedError, err) |  | ||||||
| 			assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession) |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	idToken, _ := newSignedTestIDToken(defaultIDToken) | 	idToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
|  | @ -569,7 +242,11 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { | ||||||
| 			GroupsClaim:   "groups", | 			GroupsClaim:   "groups", | ||||||
| 			ExpectedUser:  "123456789", | 			ExpectedUser:  "123456789", | ||||||
| 			ExpectedEmail: "complex@claims.com", | 			ExpectedEmail: "complex@claims.com", | ||||||
| 			ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, | 			ExpectedGroups: []string{ | ||||||
|  | 				"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", | ||||||
|  | 				"12345", | ||||||
|  | 				"Just::A::String", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	for testName, tc := range testCases { | 	for testName, tc := range testCases { | ||||||
|  |  | ||||||
|  | @ -5,20 +5,23 @@ import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"reflect" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-oidc/v3/oidc" | 	"github.com/coreos/go-oidc/v3/oidc" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| 	internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" | 	internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/oidc" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/util" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	OIDCEmailClaim  = "email" | 	OIDCEmailClaim  = "email" | ||||||
| 	OIDCGroupsClaim = "groups" | 	OIDCGroupsClaim = "groups" | ||||||
|  | 	// This is not exported as it's not currently user configurable
 | ||||||
|  | 	oidcUserClaim = "sub" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var OIDCAudienceClaims = []string{"aud"} | var OIDCAudienceClaims = []string{"aud"} | ||||||
|  | @ -52,6 +55,8 @@ type ProviderData struct { | ||||||
| 	// Universal Group authorization data structure
 | 	// Universal Group authorization data structure
 | ||||||
| 	// any provider can set to consume
 | 	// any provider can set to consume
 | ||||||
| 	AllowedGroups map[string]struct{} | 	AllowedGroups map[string]struct{} | ||||||
|  | 
 | ||||||
|  | 	getAuthorizationHeaderFunc func(string) http.Header | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Data returns the ProviderData
 | // Data returns the ProviderData
 | ||||||
|  | @ -99,6 +104,10 @@ func (p *ProviderData) setProviderDefaults(defaults providerDefaults) { | ||||||
| 	if p.Scope == "" { | 	if p.Scope == "" { | ||||||
| 		p.Scope = defaults.scope | 		p.Scope = defaults.scope | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	if p.UserClaim == "" { | ||||||
|  | 		p.UserClaim = oidcUserClaim | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // defaultURL will set return a default value if the given value is not set.
 | // defaultURL will set return a default value if the given value is not set.
 | ||||||
|  | @ -120,17 +129,6 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL { | ||||||
| // OIDC compliant
 | // OIDC compliant
 | ||||||
| // ****************************************************************************
 | // ****************************************************************************
 | ||||||
| 
 | 
 | ||||||
| // OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
 |  | ||||||
| type OIDCClaims struct { |  | ||||||
| 	Subject  string   `json:"sub"` |  | ||||||
| 	Email    string   `json:"-"` |  | ||||||
| 	Groups   []string `json:"-"` |  | ||||||
| 	Verified *bool    `json:"email_verified"` |  | ||||||
| 	Nonce    string   `json:"nonce"` |  | ||||||
| 
 |  | ||||||
| 	raw map[string]interface{} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { | func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { | ||||||
| 	rawIDToken := getIDToken(token) | 	rawIDToken := getIDToken(token) | ||||||
| 	if strings.TrimSpace(rawIDToken) == "" { | 	if strings.TrimSpace(rawIDToken) == "" { | ||||||
|  | @ -144,110 +142,80 @@ func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) ( | ||||||
| 
 | 
 | ||||||
| // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
 | // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
 | ||||||
| // with non-Token related fields.
 | // with non-Token related fields.
 | ||||||
| func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { | func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*sessions.SessionState, error) { | ||||||
| 	ss := &sessions.SessionState{} | 	ss := &sessions.SessionState{} | ||||||
| 
 | 
 | ||||||
| 	if idToken == nil { | 	if rawIDToken == "" { | ||||||
| 		return ss, nil | 		return ss, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	claims, err := p.getClaims(idToken) | 	extractor, err := p.getClaimExtractor(rawIDToken, accessToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) | 		return nil, err | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	ss.User = claims.Subject |  | ||||||
| 	ss.Email = claims.Email |  | ||||||
| 	ss.Groups = claims.Groups |  | ||||||
| 
 |  | ||||||
| 	// Allow specialized providers that embed OIDCProvider to control the User
 |  | ||||||
| 	// claim. Not exposed as a configuration flag to generic OIDC provider
 |  | ||||||
| 	// users (yet).
 |  | ||||||
| 	if p.UserClaim != "" { |  | ||||||
| 		user, ok := claims.raw[p.UserClaim].(string) |  | ||||||
| 		if !ok { |  | ||||||
| 			return nil, fmt.Errorf("unable to extract custom UserClaim (%s)", p.UserClaim) |  | ||||||
| 		} |  | ||||||
| 		ss.User = user |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Use a slice of a struct (vs map) here in case the same claim is used twice
 | ||||||
|  | 	for _, c := range []struct { | ||||||
|  | 		claim string | ||||||
|  | 		dst   interface{} | ||||||
|  | 	}{ | ||||||
|  | 		{p.UserClaim, &ss.User}, | ||||||
|  | 		{p.EmailClaim, &ss.Email}, | ||||||
|  | 		{p.GroupsClaim, &ss.Groups}, | ||||||
| 		// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
 | 		// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
 | ||||||
| 	if pref, ok := claims.raw["preferred_username"].(string); ok { | 		{"preferred_username", &ss.PreferredUsername}, | ||||||
| 		ss.PreferredUsername = pref | 	} { | ||||||
|  | 		if _, err := extractor.GetClaimInto(c.claim, c.dst); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// `email_verified` must be present and explicitly set to `false` to be
 | 	// `email_verified` must be present and explicitly set to `false` to be
 | ||||||
| 	// considered unverified.
 | 	// considered unverified.
 | ||||||
| 	verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail | 	verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail | ||||||
| 	if verifyEmail && claims.Verified != nil && !*claims.Verified { | 
 | ||||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | 	var verified bool | ||||||
|  | 	exists, err := extractor.GetClaimInto("email_verified", &verified) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if verifyEmail && exists && !verified { | ||||||
|  | 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", ss.Email) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return ss, nil | 	return ss, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getClaims extracts IDToken claims into an OIDCClaims
 | func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) { | ||||||
| func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { | 	extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken)) | ||||||
| 	claims := &OIDCClaims{} | 	if err != nil { | ||||||
| 
 | 		return nil, fmt.Errorf("could not initialise claim extractor: %v", err) | ||||||
| 	// Extract default claims.
 |  | ||||||
| 	if err := idToken.Claims(&claims); err != nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to parse default id_token claims: %v", err) |  | ||||||
| 	} |  | ||||||
| 	// Extract custom claims.
 |  | ||||||
| 	if err := idToken.Claims(&claims.raw); err != nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	email := claims.raw[p.EmailClaim] | 	return extractor, nil | ||||||
| 	if email != nil { |  | ||||||
| 		claims.Email = fmt.Sprint(email) |  | ||||||
| 	} |  | ||||||
| 	claims.Groups = p.extractGroups(claims.raw) |  | ||||||
| 
 |  | ||||||
| 	return claims, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // checkNonce compares the session's nonce with the IDToken's nonce claim
 | // checkNonce compares the session's nonce with the IDToken's nonce claim
 | ||||||
| func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error { | func (p *ProviderData) checkNonce(s *sessions.SessionState) error { | ||||||
| 	claims, err := p.getClaims(idToken) | 	extractor, err := p.getClaimExtractor(s.IDToken, "") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("id_token claims extraction failed: %v", err) | 		return fmt.Errorf("id_token claims extraction failed: %v", err) | ||||||
| 	} | 	} | ||||||
| 	if !s.CheckNonce(claims.Nonce) { | 	var nonce string | ||||||
|  | 	if _, err := extractor.GetClaimInto("nonce", &nonce); err != nil { | ||||||
|  | 		return fmt.Errorf("could not extract nonce from ID Token: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !s.CheckNonce(nonce) { | ||||||
| 		return errors.New("id_token nonce claim does not match the session nonce") | 		return errors.New("id_token nonce claim does not match the session nonce") | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // extractGroups extracts groups from a claim to a list in a type safe manner.
 | func (p *ProviderData) getAuthorizationHeader(accessToken string) http.Header { | ||||||
| // If the claim isn't present, `nil` is returned. If the groups claim is
 | 	if p.getAuthorizationHeaderFunc != nil && accessToken != "" { | ||||||
| // present but empty, `[]string{}` is returned.
 | 		return p.getAuthorizationHeaderFunc(accessToken) | ||||||
| func (p *ProviderData) extractGroups(claims map[string]interface{}) []string { | 	} | ||||||
| 	rawClaim, ok := claims[p.GroupsClaim] |  | ||||||
| 	if !ok { |  | ||||||
| 	return nil | 	return nil | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Handle traditional list-based groups as well as non-standard singleton
 |  | ||||||
| 	// based groups. Both variants support complex objects if needed.
 |  | ||||||
| 	var claimGroups []interface{} |  | ||||||
| 	switch raw := rawClaim.(type) { |  | ||||||
| 	case []interface{}: |  | ||||||
| 		claimGroups = raw |  | ||||||
| 	case interface{}: |  | ||||||
| 		claimGroups = []interface{}{raw} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	groups := []string{} |  | ||||||
| 	for _, rawGroup := range claimGroups { |  | ||||||
| 		formattedGroup, err := formatGroup(rawGroup) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Errorf("Warning: unable to format group of type %s with error %s", |  | ||||||
| 				reflect.TypeOf(rawGroup), err) |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		groups = append(groups, formattedGroup) |  | ||||||
| 	} |  | ||||||
| 	return groups |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -60,16 +60,30 @@ var ( | ||||||
| 		StandardClaims: standardClaims, | 		StandardClaims: standardClaims, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	numericGroupsIDToken = idTokenClaims{ | ||||||
|  | 		Name:           "Jane Dobbs", | ||||||
|  | 		Email:          "janed@me.com", | ||||||
|  | 		Phone:          "+4798765432", | ||||||
|  | 		Picture:        "http://mugbook.com/janed/me.jpg", | ||||||
|  | 		Groups:         []interface{}{1, 2, 3}, | ||||||
|  | 		Roles:          []string{"test:c", "test:d"}, | ||||||
|  | 		Verified:       &verified, | ||||||
|  | 		Nonce:          encryption.HashNonce([]byte(oidcNonce)), | ||||||
|  | 		StandardClaims: standardClaims, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	complexGroupsIDToken = idTokenClaims{ | 	complexGroupsIDToken = idTokenClaims{ | ||||||
| 		Name:    "Complex Claim", | 		Name:    "Complex Claim", | ||||||
| 		Email:   "complex@claims.com", | 		Email:   "complex@claims.com", | ||||||
| 		Phone:   "+5439871234", | 		Phone:   "+5439871234", | ||||||
| 		Picture: "http://mugbook.com/complex/claims.jpg", | 		Picture: "http://mugbook.com/complex/claims.jpg", | ||||||
| 		Groups: []map[string]interface{}{ | 		Groups: []interface{}{ | ||||||
| 			{ | 			map[string]interface{}{ | ||||||
| 				"groupId": "Admin Group Id", | 				"groupId": "Admin Group Id", | ||||||
| 				"roles":   []string{"Admin"}, | 				"roles":   []string{"Admin"}, | ||||||
| 			}, | 			}, | ||||||
|  | 			12345, | ||||||
|  | 			"Just::A::String", | ||||||
| 		}, | 		}, | ||||||
| 		Roles:          []string{"test:simple", "test:roles"}, | 		Roles:          []string{"test:simple", "test:roles"}, | ||||||
| 		Verified:       &verified, | 		Verified:       &verified, | ||||||
|  | @ -228,6 +242,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			AllowUnverified: false, | 			AllowUnverified: false, | ||||||
| 			EmailClaim:      "email", | 			EmailClaim:      "email", | ||||||
| 			GroupsClaim:     "groups", | 			GroupsClaim:     "groups", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:              "123456789", | 				User:              "123456789", | ||||||
| 				Email:             "janed@me.com", | 				Email:             "janed@me.com", | ||||||
|  | @ -247,6 +262,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			AllowUnverified: true, | 			AllowUnverified: true, | ||||||
| 			EmailClaim:      "email", | 			EmailClaim:      "email", | ||||||
| 			GroupsClaim:     "groups", | 			GroupsClaim:     "groups", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:              "123456789", | 				User:              "123456789", | ||||||
| 				Email:             "unverified@email.com", | 				Email:             "unverified@email.com", | ||||||
|  | @ -259,10 +275,15 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			AllowUnverified: true, | 			AllowUnverified: true, | ||||||
| 			EmailClaim:      "email", | 			EmailClaim:      "email", | ||||||
| 			GroupsClaim:     "groups", | 			GroupsClaim:     "groups", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:  "123456789", | 				User:  "123456789", | ||||||
| 				Email: "complex@claims.com", | 				Email: "complex@claims.com", | ||||||
| 				Groups:            []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, | 				Groups: []string{ | ||||||
|  | 					"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", | ||||||
|  | 					"12345", | ||||||
|  | 					"Just::A::String", | ||||||
|  | 				}, | ||||||
| 				PreferredUsername: "Complex Claim", | 				PreferredUsername: "Complex Claim", | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
|  | @ -279,19 +300,25 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 				PreferredUsername: "Jane Dobbs", | 				PreferredUsername: "Jane Dobbs", | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		"User Claim Invalid": { | 		"User Claim switched to non string": { | ||||||
| 			IDToken:         defaultIDToken, | 			IDToken:         defaultIDToken, | ||||||
| 			AllowUnverified: true, | 			AllowUnverified: true, | ||||||
| 			UserClaim:       "groups", | 			UserClaim:       "roles", | ||||||
| 			EmailClaim:      "email", | 			EmailClaim:      "email", | ||||||
| 			GroupsClaim:     "groups", | 			GroupsClaim:     "groups", | ||||||
| 			ExpectedError:   errors.New("unable to extract custom UserClaim (groups)"), | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:              "[\"test:c\",\"test:d\"]", | ||||||
|  | 				Email:             "janed@me.com", | ||||||
|  | 				Groups:            []string{"test:a", "test:b"}, | ||||||
|  | 				PreferredUsername: "Jane Dobbs", | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		"Email Claim Switched": { | 		"Email Claim Switched": { | ||||||
| 			IDToken:         unverifiedIDToken, | 			IDToken:         unverifiedIDToken, | ||||||
| 			AllowUnverified: true, | 			AllowUnverified: true, | ||||||
| 			EmailClaim:      "phone_number", | 			EmailClaim:      "phone_number", | ||||||
| 			GroupsClaim:     "groups", | 			GroupsClaim:     "groups", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:              "123456789", | 				User:              "123456789", | ||||||
| 				Email:             "+4025205729", | 				Email:             "+4025205729", | ||||||
|  | @ -304,9 +331,10 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			AllowUnverified: true, | 			AllowUnverified: true, | ||||||
| 			EmailClaim:      "roles", | 			EmailClaim:      "roles", | ||||||
| 			GroupsClaim:     "groups", | 			GroupsClaim:     "groups", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:              "123456789", | 				User:              "123456789", | ||||||
| 				Email:             "[test:c test:d]", | 				Email:             "[\"test:c\",\"test:d\"]", | ||||||
| 				Groups:            []string{"test:a", "test:b"}, | 				Groups:            []string{"test:a", "test:b"}, | ||||||
| 				PreferredUsername: "Mystery Man", | 				PreferredUsername: "Mystery Man", | ||||||
| 			}, | 			}, | ||||||
|  | @ -316,6 +344,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			AllowUnverified: true, | 			AllowUnverified: true, | ||||||
| 			EmailClaim:      "aksjdfhjksadh", | 			EmailClaim:      "aksjdfhjksadh", | ||||||
| 			GroupsClaim:     "groups", | 			GroupsClaim:     "groups", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:              "123456789", | 				User:              "123456789", | ||||||
| 				Email:             "", | 				Email:             "", | ||||||
|  | @ -328,6 +357,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			AllowUnverified: false, | 			AllowUnverified: false, | ||||||
| 			EmailClaim:      "email", | 			EmailClaim:      "email", | ||||||
| 			GroupsClaim:     "roles", | 			GroupsClaim:     "roles", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:              "123456789", | 				User:              "123456789", | ||||||
| 				Email:             "janed@me.com", | 				Email:             "janed@me.com", | ||||||
|  | @ -340,6 +370,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			AllowUnverified: false, | 			AllowUnverified: false, | ||||||
| 			EmailClaim:      "email", | 			EmailClaim:      "email", | ||||||
| 			GroupsClaim:     "alskdjfsalkdjf", | 			GroupsClaim:     "alskdjfsalkdjf", | ||||||
|  | 			UserClaim:       "sub", | ||||||
| 			ExpectedSession: &sessions.SessionState{ | 			ExpectedSession: &sessions.SessionState{ | ||||||
| 				User:              "123456789", | 				User:              "123456789", | ||||||
| 				Email:             "janed@me.com", | 				Email:             "janed@me.com", | ||||||
|  | @ -347,6 +378,32 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 				PreferredUsername: "Jane Dobbs", | 				PreferredUsername: "Jane Dobbs", | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
|  | 		"Groups Claim Numeric values": { | ||||||
|  | 			IDToken:         numericGroupsIDToken, | ||||||
|  | 			AllowUnverified: false, | ||||||
|  | 			EmailClaim:      "email", | ||||||
|  | 			GroupsClaim:     "groups", | ||||||
|  | 			UserClaim:       "sub", | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:              "123456789", | ||||||
|  | 				Email:             "janed@me.com", | ||||||
|  | 				Groups:            []string{"1", "2", "3"}, | ||||||
|  | 				PreferredUsername: "Jane Dobbs", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"Groups Claim string values": { | ||||||
|  | 			IDToken:         defaultIDToken, | ||||||
|  | 			AllowUnverified: false, | ||||||
|  | 			EmailClaim:      "email", | ||||||
|  | 			GroupsClaim:     "email", | ||||||
|  | 			UserClaim:       "sub", | ||||||
|  | 			ExpectedSession: &sessions.SessionState{ | ||||||
|  | 				User:              "123456789", | ||||||
|  | 				Email:             "janed@me.com", | ||||||
|  | 				Groups:            []string{"janed@me.com"}, | ||||||
|  | 				PreferredUsername: "Jane Dobbs", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| 	for testName, tc := range testCases { | 	for testName, tc := range testCases { | ||||||
| 		t.Run(testName, func(t *testing.T) { | 		t.Run(testName, func(t *testing.T) { | ||||||
|  | @ -371,10 +428,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { | ||||||
| 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | ||||||
| 			g.Expect(err).ToNot(HaveOccurred()) | 			g.Expect(err).ToNot(HaveOccurred()) | ||||||
| 
 | 
 | ||||||
| 			idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) | 			ss, err := provider.buildSessionFromClaims(rawIDToken, "") | ||||||
| 			g.Expect(err).ToNot(HaveOccurred()) |  | ||||||
| 
 |  | ||||||
| 			ss, err := provider.buildSessionFromClaims(idToken) |  | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				g.Expect(err).To(Equal(tc.ExpectedError)) | 				g.Expect(err).To(Equal(tc.ExpectedError)) | ||||||
| 			} | 			} | ||||||
|  | @ -418,6 +472,12 @@ func TestProviderData_checkNonce(t *testing.T) { | ||||||
| 		t.Run(testName, func(t *testing.T) { | 		t.Run(testName, func(t *testing.T) { | ||||||
| 			g := NewWithT(t) | 			g := NewWithT(t) | ||||||
| 
 | 
 | ||||||
|  | 			// Ensure that the ID token in the session is valid (signed and contains a nonce)
 | ||||||
|  | 			// as the nonce claim is extracted to compare with the session nonce
 | ||||||
|  | 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | ||||||
|  | 			g.Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			tc.Session.IDToken = rawIDToken | ||||||
|  | 
 | ||||||
| 			verificationOptions := &internaloidc.IDTokenVerificationOptions{ | 			verificationOptions := &internaloidc.IDTokenVerificationOptions{ | ||||||
| 				AudienceClaims: []string{"aud"}, | 				AudienceClaims: []string{"aud"}, | ||||||
| 				ClientID:       oidcClientID, | 				ClientID:       oidcClientID, | ||||||
|  | @ -430,14 +490,7 @@ func TestProviderData_checkNonce(t *testing.T) { | ||||||
| 				), verificationOptions), | 				), verificationOptions), | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | 			if err := provider.checkNonce(tc.Session); err != nil { | ||||||
| 			g.Expect(err).ToNot(HaveOccurred()) |  | ||||||
| 
 |  | ||||||
| 			idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) |  | ||||||
| 			g.Expect(err).ToNot(HaveOccurred()) |  | ||||||
| 
 |  | ||||||
| 			err = provider.checkNonce(tc.Session, idToken) |  | ||||||
| 			if err != nil { |  | ||||||
| 				g.Expect(err).To(Equal(tc.ExpectedError)) | 				g.Expect(err).To(Equal(tc.ExpectedError)) | ||||||
| 			} else { | 			} else { | ||||||
| 				g.Expect(err).ToNot(HaveOccurred()) | 				g.Expect(err).ToNot(HaveOccurred()) | ||||||
|  | @ -445,95 +498,3 @@ func TestProviderData_checkNonce(t *testing.T) { | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func TestProviderData_extractGroups(t *testing.T) { |  | ||||||
| 	testCases := map[string]struct { |  | ||||||
| 		Claims         map[string]interface{} |  | ||||||
| 		GroupsClaim    string |  | ||||||
| 		ExpectedGroups []string |  | ||||||
| 	}{ |  | ||||||
| 		"Standard String Groups": { |  | ||||||
| 			Claims: map[string]interface{}{ |  | ||||||
| 				"email":  "this@does.not.matter.com", |  | ||||||
| 				"groups": []interface{}{"three", "string", "groups"}, |  | ||||||
| 			}, |  | ||||||
| 			GroupsClaim:    "groups", |  | ||||||
| 			ExpectedGroups: []string{"three", "string", "groups"}, |  | ||||||
| 		}, |  | ||||||
| 		"Different Claim Name": { |  | ||||||
| 			Claims: map[string]interface{}{ |  | ||||||
| 				"email": "this@does.not.matter.com", |  | ||||||
| 				"roles": []interface{}{"three", "string", "roles"}, |  | ||||||
| 			}, |  | ||||||
| 			GroupsClaim:    "roles", |  | ||||||
| 			ExpectedGroups: []string{"three", "string", "roles"}, |  | ||||||
| 		}, |  | ||||||
| 		"Numeric Groups": { |  | ||||||
| 			Claims: map[string]interface{}{ |  | ||||||
| 				"email":  "this@does.not.matter.com", |  | ||||||
| 				"groups": []interface{}{1, 2, 3}, |  | ||||||
| 			}, |  | ||||||
| 			GroupsClaim:    "groups", |  | ||||||
| 			ExpectedGroups: []string{"1", "2", "3"}, |  | ||||||
| 		}, |  | ||||||
| 		"Complex Groups": { |  | ||||||
| 			Claims: map[string]interface{}{ |  | ||||||
| 				"email": "this@does.not.matter.com", |  | ||||||
| 				"groups": []interface{}{ |  | ||||||
| 					map[string]interface{}{ |  | ||||||
| 						"groupId": "Admin Group Id", |  | ||||||
| 						"roles":   []string{"Admin"}, |  | ||||||
| 					}, |  | ||||||
| 					12345, |  | ||||||
| 					"Just::A::String", |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 			GroupsClaim: "groups", |  | ||||||
| 			ExpectedGroups: []string{ |  | ||||||
| 				"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", |  | ||||||
| 				"12345", |  | ||||||
| 				"Just::A::String", |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"Missing Groups Claim Returns Nil": { |  | ||||||
| 			Claims: map[string]interface{}{ |  | ||||||
| 				"email": "this@does.not.matter.com", |  | ||||||
| 			}, |  | ||||||
| 			GroupsClaim:    "groups", |  | ||||||
| 			ExpectedGroups: nil, |  | ||||||
| 		}, |  | ||||||
| 		"Non List Groups": { |  | ||||||
| 			Claims: map[string]interface{}{ |  | ||||||
| 				"email":  "this@does.not.matter.com", |  | ||||||
| 				"groups": "singleton", |  | ||||||
| 			}, |  | ||||||
| 			GroupsClaim:    "groups", |  | ||||||
| 			ExpectedGroups: []string{"singleton"}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	for testName, tc := range testCases { |  | ||||||
| 		t.Run(testName, func(t *testing.T) { |  | ||||||
| 			g := NewWithT(t) |  | ||||||
| 
 |  | ||||||
| 			verificationOptions := &internaloidc.IDTokenVerificationOptions{ |  | ||||||
| 				AudienceClaims: []string{"aud"}, |  | ||||||
| 				ClientID:       oidcClientID, |  | ||||||
| 			} |  | ||||||
| 			provider := &ProviderData{ |  | ||||||
| 				Verifier: internaloidc.NewVerifier(oidc.NewVerifier( |  | ||||||
| 					oidcIssuer, |  | ||||||
| 					mockJWKS{}, |  | ||||||
| 					&oidc.Config{ClientID: oidcClientID}, |  | ||||||
| 				), verificationOptions), |  | ||||||
| 			} |  | ||||||
| 			provider.GroupsClaim = tc.GroupsClaim |  | ||||||
| 
 |  | ||||||
| 			groups := provider.extractGroups(tc.Claims) |  | ||||||
| 			if tc.ExpectedGroups != nil { |  | ||||||
| 				g.Expect(groups).To(Equal(tc.ExpectedGroups)) |  | ||||||
| 			} else { |  | ||||||
| 				g.Expect(groups).To(BeNil()) |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -6,7 +6,6 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -83,18 +82,3 @@ func formatGroup(rawGroup interface{}) (string, error) { | ||||||
| 	} | 	} | ||||||
| 	return string(jsonGroup), nil | 	return string(jsonGroup), nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // coerceArray extracts a field from simplejson.Json that might be a
 |  | ||||||
| // singleton or a list and coerces it into a list.
 |  | ||||||
| func coerceArray(sj *simplejson.Json, key string) []interface{} { |  | ||||||
| 	array, err := sj.Get(key).Array() |  | ||||||
| 	if err == nil { |  | ||||||
| 		return array |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	single := sj.Get(key).Interface() |  | ||||||
| 	if single == nil { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	return []interface{}{single} |  | ||||||
| } |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue