From 10adb5c516b5a15756a7baa50aa2d8551a6655b8 Mon Sep 17 00:00:00 2001 From: Ben Letton Date: Thu, 6 Feb 2020 18:09:30 +0000 Subject: [PATCH] Support optional id_tokens in refresh responses (#335) * OIDC Token Refresh works without id_tokens Addresses https://github.com/pusher/oauth2_proxy/issues/318 Refactoring the OIDC provider so that the refresh process works when there are no id_tokens present in the response. Added unit tests to the oidc_test.go to prove the redeem and refresh still work. The expiry time of the session is now taken from the outh token expiry and not the id_token (preventing stale access_tokens in sessions). * Refactoring the to use a KeySetStub in the oidc_test.go. This allows the elimination of the slightly contrived function passing elements used previously. (This change is being applied to address the bug #318) * Changes as per the PR comments and preparing for 5.x release * Fixup changelog Co-authored-by: Joel Speed --- CHANGELOG.md | 3 +- providers/oidc.go | 167 +++++++++++++++++--------- providers/oidc_test.go | 264 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 374 insertions(+), 60 deletions(-) create mode 100644 providers/oidc_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index ca3c8817..510d3fca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,11 +3,13 @@ ## Release Hightlights ## Important Notes +- [#335] The session expiry for the OIDC provider is now taken from the Token Response (expires_in) rather than from the id_token (exp) ## Breaking Changes ## Changes since v5.0.0 +- [#335](https://github.com/pusher/oauth2_proxy/pull/335) OIDC Provider support for empty id_tokens in the access token refresh response (@howzat) - [#363](https://github.com/pusher/oauth2_proxy/pull/363) Extension of Redis Session Store to Support Redis Cluster (@yan-dblinf) - [#353](https://github.com/pusher/oauth2_proxy/pull/353) Fix login page fragment handling after soft reload on Firefox (@ffdybuster) @@ -37,7 +39,6 @@ - [#179](https://github.com/pusher/oauth2_proxy/pull/179) Add Nextcloud provider (@Ramblurr) - [#280](https://github.com/pusher/oauth2_proxy/pull/280) whitelisted redirect domains: add support for whitelisting specific ports or allowing wildcard ports (@kamaln7) - [#351](https://github.com/pusher/oauth2_proxy/pull/351) Add DigitalOcean Auth provider (@kamaln7) - # v4.1.0 ## Release Highlights diff --git a/providers/oidc.go b/providers/oidc.go index 2abf2cae..06b3206c 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "time" oidc "github.com/coreos/go-oidc" @@ -42,28 +43,36 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat if err != nil { return nil, fmt.Errorf("token exchange: %v", err) } - s, err = p.createSessionState(ctx, token) + + // in the initial exchange the id token is mandatory + idToken, err := p.findVerifiedIDToken(ctx, token) + if err != nil { + return nil, fmt.Errorf("could not verify id_token: %v", err) + } else if idToken == nil { + return nil, fmt.Errorf("token response did not contain an id_token") + } + + s, err = p.createSessionState(token, idToken) if err != nil { return nil, fmt.Errorf("unable to update session: %v", err) } + return } // RefreshSessionIfNeeded checks if the session has expired and uses the -// RefreshToken to fetch a new ID token if required +// RefreshToken to fetch a new Access Token (and optional ID token) if required func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } - origExpiration := s.ExpiresOn - err := p.redeemRefreshToken(s) if err != nil { return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) + fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn) return true, nil } @@ -84,81 +93,76 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) if err != nil { return fmt.Errorf("failed to get token: %v", err) } - newSession, err := p.createSessionState(ctx, token) + + // in the token refresh response the id_token is optional + idToken, err := p.findVerifiedIDToken(ctx, token) if err != nil { - return fmt.Errorf("unable to update session: %v", err) + return fmt.Errorf("unable to extract id_token from response: %v", err) } + + newSession, err := p.createSessionState(token, idToken) + if err != nil { + return fmt.Errorf("unable create new session state from response: %v", err) + } + + // It's possible that if the refresh token isn't in the token response the session will not contain an id token + // if it doesn't it's probably better to retain the old one + if newSession.IDToken != "" { + s.IDToken = newSession.IDToken + s.Email = newSession.Email + s.User = newSession.User + } + s.AccessToken = newSession.AccessToken - s.IDToken = newSession.IDToken s.RefreshToken = newSession.RefreshToken s.CreatedAt = newSession.CreatedAt s.ExpiresOn = newSession.ExpiresOn - s.Email = newSession.Email + return } -func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return nil, fmt.Errorf("token response did not contain an id_token") +func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { + + getIDToken := func() (string, bool) { + rawIDToken, _ := token.Extra("id_token").(string) + return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0 } - // Parse and verify ID Token payload. - idToken, err := p.Verifier.Verify(ctx, rawIDToken) - if err != nil { - return nil, fmt.Errorf("could not verify id_token: %v", err) + if rawIDToken, present := getIDToken(); present { + verifiedIdToken, err := p.Verifier.Verify(ctx, rawIDToken) + return verifiedIdToken, err + } else { + return nil, nil } +} - // Extract custom claims. - var claims struct { - Subject string `json:"sub"` - Email string `json:"email"` - Verified *bool `json:"email_verified"` - } - if err := idToken.Claims(&claims); err != nil { - return nil, fmt.Errorf("failed to parse id_token claims: %v", err) - } +func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { - if claims.Email == "" { - if p.ProfileURL.String() == "" { - return nil, fmt.Errorf("id_token did not contain an email") - } + newSession := &sessions.SessionState{} - // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo - // contents at the profileURL contains the email. - // Make a query to the userinfo endpoint, and attempt to locate the email from there. - - req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) + if idToken != nil { + claims, err := findClaimsFromIDToken(idToken, token.AccessToken, p.ProfileURL.String()) if err != nil { - return nil, err - } - req.Header = getOIDCHeader(token.AccessToken) - - respJSON, err := requests.Request(req) - if err != nil { - return nil, err + return nil, fmt.Errorf("couldn't extract claims from id_token (%e)", err) } - email, err := respJSON.Get("email").String() - if err != nil { - return nil, fmt.Errorf("Neither id_token nor userinfo endpoint contained an email") - } + if claims != nil { - claims.Email = email - } - if !p.AllowUnverifiedEmail && claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + if !p.AllowUnverifiedEmail && claims.Verified != nil && !*claims.Verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + } + + newSession.IDToken = token.Extra("id_token").(string) + newSession.Email = claims.Email + newSession.User = claims.Subject + } } - return &sessions.SessionState{ - AccessToken: token.AccessToken, - IDToken: rawIDToken, - RefreshToken: token.RefreshToken, - CreatedAt: time.Now(), - ExpiresOn: idToken.Expiry, - Email: claims.Email, - User: claims.Subject, - }, nil + newSession.AccessToken = token.AccessToken + newSession.RefreshToken = token.RefreshToken + newSession.CreatedAt = time.Now() + newSession.ExpiresOn = token.Expiry + return newSession, nil } // ValidateSessionState checks that the session's IDToken is still valid @@ -178,3 +182,48 @@ func getOIDCHeader(accessToken string) http.Header { header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) return header } + +func findClaimsFromIDToken(idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) { + + // Extract custom claims. + claims := &OIDCClaims{} + if err := idToken.Claims(claims); err != nil { + return nil, fmt.Errorf("failed to parse id_token claims: %v", err) + } + + if claims.Email == "" { + if profileURL == "" { + return nil, fmt.Errorf("id_token did not contain an email") + } + + // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo + // contents at the profileURL contains the email. + // Make a query to the userinfo endpoint, and attempt to locate the email from there. + + req, err := http.NewRequest("GET", profileURL, nil) + if err != nil { + return nil, err + } + req.Header = getOIDCHeader(accessToken) + + respJSON, err := requests.Request(req) + if err != nil { + return nil, err + } + + email, err := respJSON.Get("email").String() + if err != nil { + return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained an email") + } + + claims.Email = email + } + + return claims, nil +} + +type OIDCClaims struct { + Subject string `json:"sub"` + Email string `json:"email"` + Verified *bool `json:"email_verified"` +} diff --git a/providers/oidc_test.go b/providers/oidc_test.go new file mode 100644 index 00000000..865aac85 --- /dev/null +++ b/providers/oidc_test.go @@ -0,0 +1,264 @@ +package providers + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "golang.org/x/oauth2" + + "github.com/bmizerany/assert" + "github.com/coreos/go-oidc" + "github.com/dgrijalva/jwt-go" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +const accessToken = "access_token" +const refreshToken = "refresh_token" +const clientID = "https://test.myapp.com" +const secret = "secret" + +type idTokenClaims struct { + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + Picture string `json:"picture,omitempty"` + jwt.StandardClaims +} + +type redeemTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` + IDToken string `json:"id_token,omitempty"` +} + +var defaultIDToken idTokenClaims = idTokenClaims{ + "Jane Dobbs", + "janed@me.com", + "http://mugbook.com/janed/me.jpg", + jwt.StandardClaims{ + Audience: "https://test.myapp.com", + ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: "https://issuer.example.com", + NotBefore: 0, + Subject: "123456789", + }, +} + +type fakeKeySetStub struct {} + +func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { + decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1]) + tokenClaims := &idTokenClaims{} + err = json.Unmarshal(decodeString, tokenClaims) + + if err != nil || tokenClaims.Id == "this-id-fails-validation" { + return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject) + } + + return decodeString, err +} + +func newOIDCProvider(serverURL *url.URL) *OIDCProvider { + + providerData := &ProviderData{ + ProviderName: "oidc", + ClientID: clientID, + ClientSecret: secret, + LoginURL: &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/login/oauth/authorize"}, + RedeemURL: &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/login/oauth/access_token"}, + ProfileURL: &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/profile"}, + ValidateURL: &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/api"}, + Scope: "openid profile offline_access"} + + p := &OIDCProvider{ + ProviderData: providerData, + Verifier: oidc.NewVerifier( + "https://issuer.example.com", + fakeKeySetStub{}, + &oidc.Config{ClientID: clientID}, + ), + } + + return p +} + +func newOIDCServer(body []byte) (*url.URL, *httptest.Server) { + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Add("content-type", "application/json") + _, _ = rw.Write(body) + })) + u, _ := url.Parse(s.URL) + return u, s +} + +func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) { + + key, _ := rsa.GenerateKey(rand.Reader, 2048) + standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) + return standardClaims.SignedString(key) +} + +func newOauth2Token() *oauth2.Token { + return &oauth2.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + RefreshToken: refreshToken, + Expiry: time.Time{}.Add(time.Duration(5) * time.Second), + } +} + +func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) { + redeemURL, server := newOIDCServer(body) + provider := newOIDCProvider(redeemURL) + return server, provider +} + +func TestOIDCProviderRedeem(t *testing.T) { + + idToken, _ := newSignedTestIDToken(defaultIDToken) + body, _ := json.Marshal(redeemTokenResponse{ + AccessToken: accessToken, + ExpiresIn: 10, + TokenType: "Bearer", + RefreshToken: refreshToken, + IDToken: idToken, + }) + + server, provider := newTestSetup(body) + defer server.Close() + + session, err := provider.Redeem(provider.RedeemURL.String(), "code1234") + assert.Equal(t, nil, err) + assert.Equal(t, defaultIDToken.Email, session.Email) + assert.Equal(t, accessToken, session.AccessToken) + assert.Equal(t, idToken, session.IDToken) + assert.Equal(t, refreshToken, session.RefreshToken) + assert.Equal(t, "123456789", session.User) +} + +func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { + + idToken, _ := newSignedTestIDToken(defaultIDToken) + body, _ := json.Marshal(redeemTokenResponse{ + AccessToken: accessToken, + ExpiresIn: 10, + TokenType: "Bearer", + RefreshToken: refreshToken, + }) + + server, provider := newTestSetup(body) + defer server.Close() + + existingSession := &sessions.SessionState{ + AccessToken: "changeit", + IDToken: idToken, + CreatedAt: time.Time{}, + ExpiresOn: time.Time{}, + RefreshToken: refreshToken, + Email: "janedoe@example.com", + User: "11223344", + } + + refreshed, err := provider.RefreshSessionIfNeeded(existingSession) + assert.Equal(t, nil, err) + assert.Equal(t, refreshed, true) + assert.Equal(t, "janedoe@example.com", existingSession.Email) + assert.Equal(t, accessToken, existingSession.AccessToken) + assert.Equal(t, idToken, existingSession.IDToken) + assert.Equal(t, refreshToken, existingSession.RefreshToken) + assert.Equal(t, "11223344", existingSession.User) +} + +func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { + + idToken, _ := newSignedTestIDToken(defaultIDToken) + body, _ := json.Marshal(redeemTokenResponse{ + AccessToken: accessToken, + ExpiresIn: 10, + TokenType: "Bearer", + RefreshToken: refreshToken, + IDToken: idToken, + }) + + server, provider := newTestSetup(body) + defer server.Close() + + existingSession := &sessions.SessionState{ + AccessToken: "changeit", + IDToken: "changeit", + CreatedAt: time.Time{}, + ExpiresOn: time.Time{}, + RefreshToken: refreshToken, + Email: "changeit", + User: "changeit", + } + refreshed, err := provider.RefreshSessionIfNeeded(existingSession) + assert.Equal(t, nil, err) + assert.Equal(t, refreshed, true) + assert.Equal(t, defaultIDToken.Email, existingSession.Email) + assert.Equal(t, defaultIDToken.Subject, existingSession.User) + assert.Equal(t, accessToken, existingSession.AccessToken) + assert.Equal(t, idToken, existingSession.IDToken) + assert.Equal(t, refreshToken, existingSession.RefreshToken) +} + +func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { + + server, provider := newTestSetup([]byte("")) + + defer server.Close() + + token := newOauth2Token() + signedIdToken, _ := newSignedTestIDToken(defaultIDToken) + tokenWithIdToken := token.WithExtra(map[string]interface{}{ + "id_token": signedIdToken, + }) + + verifiedIdToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIdToken) + assert.Equal(t, true, err == nil) + assert.Equal(t, true, verifiedIdToken != nil) + assert.Equal(t, defaultIDToken.Issuer, verifiedIdToken.Issuer) + assert.Equal(t, defaultIDToken.Subject, verifiedIdToken.Subject) + + // When the validation fails the response should be nil + defaultIDToken.Id = "this-id-fails-validation" + signedIdToken, _ = newSignedTestIDToken(defaultIDToken) + tokenWithIdToken = token.WithExtra(map[string]interface{}{ + "id_token": signedIdToken, + }) + + verifiedIdToken, err = provider.findVerifiedIDToken(context.Background(), tokenWithIdToken) + assert.Equal(t, errors.New("failed to verify signature: the validation failed for subject [123456789]"), err) + assert.Equal(t, true, verifiedIdToken == nil) + + // When there is no id token in the oauth token + verifiedIdToken, err = provider.findVerifiedIDToken(context.Background(), newOauth2Token()) + assert.Equal(t, nil, err) + assert.Equal(t, true, verifiedIdToken == nil) +}