From f3209a40e16defdab442ef87765e837d90ec4f9a Mon Sep 17 00:00:00 2001 From: Weinong Wang Date: Tue, 9 Mar 2021 20:53:15 -0800 Subject: [PATCH] extract email from id_token for azure provider (#914) * extract email from id_token for azure provider this change fixes a bug when --resource is specified with non-Graph api and the access token destined to --resource is used to call Graph api * fixed typo * refactor GetEmailAddress to EnrichSessionState * make getting email from idtoken best effort and fall back to previous behavior when it's absent * refactor to use jwt package to extract claims * fix lint * refactor unit tests to use test table refactor the get email logic from profile api * addressing feedback * added oidc verifier to azure provider and extract email from id_token if present * fix lint and codeclimate * refactor to use oidc verifier to verify id_token if oidc is configured * fixed UT * addressed comments * minor refactor * addressed feedback * extract email from id_token first and fallback to access token * fallback to access token as well when id_token doesn't have email claim * address feedbacks * updated change log! --- CHANGELOG.md | 1 + providers/azure.go | 180 ++++++++++++++++------ providers/azure_test.go | 329 +++++++++++++++++++++++++++------------- 3 files changed, 355 insertions(+), 155 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 02fb5e5a..8515e5d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.0.1 +- [#914](https://github.com/oauth2-proxy/oauth2-proxy/pull/914) Extract email from id_token for azure provider when oidc is configured - [#1047](https://github.com/oauth2-proxy/oauth2-proxy/pull/1047) Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers (@JoelSpeed) - [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) diff --git a/providers/azure.go b/providers/azure.go index 92974540..036fefae 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -107,26 +107,22 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) { } } +func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { + extraParams := url.Values{} + if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { + extraParams.Add("resource", p.ProtectedResource.String()) + } + a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) + return a.String() +} + // Redeem exchanges the OAuth2 authentication token for an ID token func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { - if code == "" { - return nil, ErrMissingCode - } - clientSecret, err := p.GetClientSecret() + params, err := p.prepareRedeem(redirectURL, code) if err != nil { return nil, err } - params := url.Values{} - params.Add("redirect_uri", redirectURL) - params.Add("client_id", p.ClientID) - params.Add("client_secret", clientSecret) - params.Add("code", code) - params.Add("grant_type", "authorization_code") - if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { - params.Add("resource", p.ProtectedResource.String()) - } - // blindly try json and x-www-form-urlencoded var jsonResponse struct { AccessToken string `json:"access_token"` @@ -149,13 +145,98 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* created := time.Now() expires := time.Unix(jsonResponse.ExpiresOn, 0) - return &sessions.SessionState{ + session := &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, CreatedAt: &created, ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, - }, nil + } + + email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) + + // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 + // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 + // due to above issues, id_token may not be signed by AAD + // in that case, we will fallback to access token + if err == nil && email != "" { + session.Email = email + } else { + logger.Printf("unable to get email claim from id_token: %v", err) + } + + if session.Email == "" { + email, err = p.verifyTokenAndExtractEmail(ctx, session.AccessToken) + if err == nil && email != "" { + session.Email = email + } else { + logger.Printf("unable to get email claim from access token: %v", err) + } + } + + return session, nil +} + +// EnrichSession finds the email to enrich the session state +func (p *AzureProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { + if s.Email != "" { + return nil + } + + email, err := p.getEmailFromProfileAPI(ctx, s.AccessToken) + if err != nil { + return fmt.Errorf("unable to get email address: %v", err) + } + if email == "" { + return errors.New("unable to get email address") + } + s.Email = email + + return nil +} + +func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, error) { + params := url.Values{} + if code == "" { + return params, ErrMissingCode + } + clientSecret, err := p.GetClientSecret() + if err != nil { + return params, err + } + + params.Add("redirect_uri", redirectURL) + params.Add("client_id", p.ClientID) + params.Add("client_secret", clientSecret) + params.Add("code", code) + params.Add("grant_type", "authorization_code") + if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { + params.Add("resource", p.ProtectedResource.String()) + } + return params, nil +} + +// verifyTokenAndExtractEmail tries to extract email claim from either id_token or access token +// when oidc verifier is configured +func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token string) (string, error) { + email := "" + + if token != "" && p.Verifier != nil { + token, err := p.Verifier.Verify(ctx, token) + // due to issues mentioned above, id_token may not be signed by AAD + if err == nil { + claims, err := p.getClaims(token) + if err == nil { + email = claims.Email + } else { + logger.Printf("unable to get claims from token: %v", err) + } + } else { + logger.Printf("unable to verify token: %v", err) + } + } + + return email, nil } // RefreshSessionIfNeeded checks if the session has expired and uses the @@ -209,6 +290,28 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess s.RefreshToken = jsonResponse.RefreshToken s.CreatedAt = &now s.ExpiresOn = &expires + + email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) + + // https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814 + // https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117 + // due to above issues, id_token may not be signed by AAD + // in that case, we will fallback to access token + if err == nil && email != "" { + s.Email = email + } else { + logger.Printf("unable to get email claim from id_token: %v", err) + } + + if s.Email == "" { + email, err = p.verifyTokenAndExtractEmail(ctx, s.AccessToken) + if err == nil && email != "" { + s.Email = email + } else { + logger.Printf("unable to get email claim from access token: %v", err) + } + } + return } @@ -230,53 +333,32 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { err = otherMailsErr } + if err != nil || email == "" { + email, err = json.Get("userPrincipalName").String() + if err != nil { + logger.Errorf("unable to find userPrincipalName: %s", err) + return "", err + } + } + return email, err } -// GetEmailAddress returns the Account email address -func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - var email string - var err error - - if s.AccessToken == "" { +func (p *AzureProvider) getEmailFromProfileAPI(ctx context.Context, accessToken string) (string, error) { + if accessToken == "" { return "", errors.New("missing access token") } json, err := requests.New(p.ProfileURL.String()). WithContext(ctx). - WithHeaders(makeAzureHeader(s.AccessToken)). + WithHeaders(makeAzureHeader(accessToken)). Do(). UnmarshalJSON() if err != nil { return "", err } - email, err = getEmailFromJSON(json) - if err == nil && email != "" { - return email, err - } - - email, err = json.Get("userPrincipalName").String() - if err != nil { - logger.Errorf("failed making request %s", err) - return "", err - } - - if email == "" { - logger.Errorf("failed to get email address") - return "", err - } - - return email, err -} - -func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { - extraParams := url.Values{} - if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { - extraParams.Add("resource", p.ProtectedResource.String()) - } - a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) - return a.String() + return getEmailFromJSON(json) } // ValidateSession validates the AccessToken diff --git a/providers/azure_test.go b/providers/azure_test.go index 9e3cabf7..35509b0d 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -2,18 +2,41 @@ package providers import ( "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" + oidc "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) +type fakeAzureKeySetStub struct{} + +func (fakeAzureKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { + decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1]) + if err != nil { + return nil, err + } + return decodeString, nil +} + +type azureOAuthPayload struct { + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresOn int64 `json:"expires_on,omitempty,string"` + IDToken string `json:"id_token,omitempty"` +} + func testAzureProvider(hostname string) *AzureProvider { p := NewAzureProvider( &ProviderData{ @@ -23,7 +46,19 @@ func testAzureProvider(hostname string) *AzureProvider { ProfileURL: &url.URL{}, ValidateURL: &url.URL{}, ProtectedResource: &url.URL{}, - Scope: ""}) + Scope: "", + EmailClaim: "email", + Verifier: oidc.NewVerifier( + "https://issuer.example.com", + fakeAzureKeySetStub{}, + &oidc.Config{ + ClientID: "cd6d4fae-f6a6-4a34-8454-2c6b598e9532", + SkipClientIDCheck: true, + SkipIssuerCheck: true, + SkipExpiryCheck: true, + }, + ), + }) if hostname != "" { updateURL(p.Data().LoginURL, hostname) @@ -104,6 +139,10 @@ func TestAzureSetTenant(t *testing.T) { } func testAzureBackend(payload string) *httptest.Server { + return testAzureBackendWithError(payload, false) +} + +func testAzureBackendWithError(payload string, injectError bool) *httptest.Server { path := "/v1.0/me" return httptest.NewServer(http.HandlerFunc( @@ -111,7 +150,11 @@ func testAzureBackend(payload string) *httptest.Server { if (r.URL.Path != path) && r.Method != http.MethodPost { w.WriteHeader(404) } else if r.Method == http.MethodPost && r.Body != nil { - w.WriteHeader(200) + if injectError { + w.WriteHeader(400) + } else { + w.WriteHeader(200) + } w.Write([]byte(payload)) } else if !IsAuthorizedInHeader(r.Header) { w.WriteHeader(403) @@ -122,98 +165,172 @@ func testAzureBackend(payload string) *httptest.Server { })) } -func TestAzureProviderGetEmailAddress(t *testing.T) { - b := testAzureBackend(`{ "mail": "user@windows.net" }`) - defer b.Close() +func TestAzureProviderEnrichSession(t *testing.T) { + testCases := []struct { + Description string + Email string + PayloadFromAzureBackend string + ExpectedEmail string + ExpectedError error + }{ + { + Description: "should return email using mail property from Azure backend", + PayloadFromAzureBackend: `{ "mail": "user@windows.net" }`, + ExpectedEmail: "user@windows.net", + }, + { + Description: "should return email using otherMails property returned from Azure backend", + PayloadFromAzureBackend: `{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`, + ExpectedEmail: "user@windows.net", + }, + { + Description: "should return email using userPrincipalName from Azure backend", + PayloadFromAzureBackend: `{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`, + ExpectedEmail: "user@windows.net", + }, + { + Description: "should return error when Azure backend doesn't return email information", + PayloadFromAzureBackend: `{ "mail": null, "otherMails": [], "userPrincipalName": null }`, + ExpectedError: fmt.Errorf("unable to get email address: %v", errors.New("type assertion to string failed")), + }, + { + Description: "should return specific error when unable to get email", + PayloadFromAzureBackend: `{ "mail": null, "otherMails": [], "userPrincipalName": "" }`, + ExpectedError: errors.New("unable to get email address"), + }, + { + Description: "should return error when otherMails from Azure backend is not a valid type", + PayloadFromAzureBackend: `{ "mail": null, "otherMails": "", "userPrincipalName": null }`, + ExpectedError: fmt.Errorf("unable to get email address: %v", errors.New("type assertion to string failed")), + }, + { + Description: "should not query profile api when email is already set in session", + Email: "user@windows.net", + ExpectedEmail: "user@windows.net", + }, + } - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) + for _, testCase := range testCases { + t.Run(testCase.Description, func(t *testing.T) { + var ( + b *httptest.Server + host string + ) + if testCase.PayloadFromAzureBackend != "" { + b = testAzureBackend(testCase.PayloadFromAzureBackend) + defer b.Close() - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "user@windows.net", email) + bURL, _ := url.Parse(b.URL) + host = bURL.Host + } + p := testAzureProvider(host) + session := CreateAuthorizedSession() + session.Email = testCase.Email + err := p.EnrichSession(context.Background(), session) + assert.Equal(t, testCase.ExpectedError, err) + assert.Equal(t, testCase.ExpectedEmail, session.Email) + }) + } } -func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { - b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`) - defer b.Close() +func TestAzureProviderRedeem(t *testing.T) { + testCases := []struct { + Name string + RefreshToken string + ExpiresOn time.Time + EmailFromIDToken string + EmailFromAccessToken string + IsIDTokenMalformed bool + InjectRedeemURLError bool + }{ + { + Name: "with id_token returned", + EmailFromIDToken: "foo1@example.com", + RefreshToken: "some_refresh_token", + ExpiresOn: time.Now().Add(time.Hour), + }, + { + Name: "without id_token returned, fallback to access token", + EmailFromAccessToken: "foo2@example.com", + RefreshToken: "some_refresh_token", + ExpiresOn: time.Now().Add(time.Hour), + }, + { + Name: "id_token malformed, fallback to access token", + EmailFromAccessToken: "foo3@example.com", + RefreshToken: "some_refresh_token", + ExpiresOn: time.Now().Add(time.Hour), + IsIDTokenMalformed: true, + }, + { + Name: "both id_token and access tokens are valid, return email from id_token", + EmailFromIDToken: "foo1@example.com", + EmailFromAccessToken: "foo3@example.com", + RefreshToken: "some_refresh_token", + ExpiresOn: time.Now().Add(time.Hour), + }, + { + Name: "redeem URL failed, should return error", + EmailFromIDToken: "foo1@example.com", + EmailFromAccessToken: "foo3@example.com", + RefreshToken: "some_refresh_token", + ExpiresOn: time.Now().Add(time.Hour), + InjectRedeemURLError: true, + }, + } - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + idTokenString := "" + accessTokenString := "" + if testCase.EmailFromIDToken != "" { + var err error + token := idTokenClaims{Email: testCase.EmailFromIDToken} + idTokenString, err = newSignedTestIDToken(token) + assert.NoError(t, err) + } + if testCase.EmailFromAccessToken != "" { + var err error + token := idTokenClaims{Email: testCase.EmailFromAccessToken} + accessTokenString, err = newSignedTestIDToken(token) + assert.NoError(t, err) + } + if testCase.IsIDTokenMalformed { + idTokenString = "this is a malformed id_token" + } + payload := azureOAuthPayload{ + IDToken: idTokenString, + RefreshToken: testCase.RefreshToken, + AccessToken: accessTokenString, + ExpiresOn: testCase.ExpiresOn.Unix(), + } - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "user@windows.net", email) -} + payloadBytes, err := json.Marshal(payload) + assert.NoError(t, err) -func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { - b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`) - defer b.Close() + b := testAzureBackendWithError(string(payloadBytes), testCase.InjectRedeemURLError) + defer b.Close() - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "user@windows.net", email) -} - -func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { - b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`) - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, "type assertion to string failed", err.Error()) - assert.Equal(t, "", email) -} - -func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { - b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`) - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "", email) -} - -func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { - b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`) - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, "type assertion to string failed", err.Error()) - assert.Equal(t, "", email) -} - -func TestAzureProviderRedeemReturnsIdToken(t *testing.T) { - b := testAzureBackend(`{ "id_token": "testtoken1234", "expires_on": "1136239445", "refresh_token": "refresh1234" }`) - defer b.Close() - timestamp, err := time.Parse(time.RFC3339, "2006-01-02T22:04:05Z") - assert.Equal(t, nil, err) - - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) - p.Data().RedeemURL.Path = "/common/oauth2/token" - s, err := p.Redeem(context.Background(), "https://localhost", "1234") - assert.Equal(t, nil, err) - assert.Equal(t, "testtoken1234", s.IDToken) - assert.Equal(t, timestamp, s.ExpiresOn.UTC()) - assert.Equal(t, "refresh1234", s.RefreshToken) + bURL, _ := url.Parse(b.URL) + p := testAzureProvider(bURL.Host) + p.Data().RedeemURL.Path = "/common/oauth2/token" + s, err := p.Redeem(context.Background(), "https://localhost", "1234") + if testCase.InjectRedeemURLError { + assert.NotNil(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, idTokenString, s.IDToken) + assert.Equal(t, accessTokenString, s.AccessToken) + assert.Equal(t, testCase.ExpiresOn.Unix(), s.ExpiresOn.Unix()) + assert.Equal(t, testCase.RefreshToken, s.RefreshToken) + if testCase.EmailFromIDToken != "" { + assert.Equal(t, testCase.EmailFromIDToken, s.Email) + } else { + assert.Equal(t, testCase.EmailFromAccessToken, s.Email) + } + } + }) + } } func TestAzureProviderProtectedResourceConfigured(t *testing.T) { @@ -223,22 +340,6 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) } -func TestAzureProviderGetsTokensInRedeem(t *testing.T) { - b := testAzureBackend(`{ "access_token": "some_access_token", "refresh_token": "some_refresh_token", "expires_on": "1136239445", "id_token": "some_id_token" }`) - defer b.Close() - timestamp, _ := time.Parse(time.RFC3339, "2006-01-02T22:04:05Z") - bURL, _ := url.Parse(b.URL) - p := testAzureProvider(bURL.Host) - - session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") - assert.Equal(t, nil, err) - assert.NotEqual(t, session, nil) - assert.Equal(t, "some_access_token", session.AccessToken) - assert.Equal(t, "some_refresh_token", session.RefreshToken) - assert.Equal(t, "some_id_token", session.IDToken) - assert.Equal(t, timestamp, session.ExpiresOn.UTC()) -} - func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { p := testAzureProvider("") @@ -250,19 +351,35 @@ func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { } func TestAzureProviderRefreshWhenExpired(t *testing.T) { - b := testAzureBackend(`{ "access_token": "new_some_access_token", "refresh_token": "new_some_refresh_token", "expires_on": "32693148245", "id_token": "new_some_id_token" }`) + email := "foo@example.com" + idToken := idTokenClaims{Email: email} + idTokenString, err := newSignedTestIDToken(idToken) + assert.NoError(t, err) + timestamp, err := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") + assert.NoError(t, err) + payload := azureOAuthPayload{ + IDToken: idTokenString, + RefreshToken: "new_some_refresh_token", + AccessToken: "new_some_access_token", + ExpiresOn: timestamp.Unix(), + } + + payloadBytes, err := json.Marshal(payload) + assert.NoError(t, err) + b := testAzureBackend(string(payloadBytes)) defer b.Close() - timestamp, _ := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) expires := time.Now().Add(time.Duration(-1) * time.Hour) session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} - _, err := p.RefreshSessionIfNeeded(context.Background(), session) + refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) assert.Equal(t, nil, err) + assert.True(t, refreshNeeded) assert.NotEqual(t, session, nil) assert.Equal(t, "new_some_access_token", session.AccessToken) assert.Equal(t, "new_some_refresh_token", session.RefreshToken) - assert.Equal(t, "new_some_id_token", session.IDToken) + assert.Equal(t, idTokenString, session.IDToken) + assert.Equal(t, email, session.Email) assert.Equal(t, timestamp, session.ExpiresOn.UTC()) }