Support optional id_tokens in refresh responses (#335)
* OIDC Token Refresh works without id_tokens Addresses https://github.com/pusher/oauth2_proxy/issues/318 Refactoring the OIDC provider so that the refresh process works when there are no id_tokens present in the response. Added unit tests to the oidc_test.go to prove the redeem and refresh still work. The expiry time of the session is now taken from the outh token expiry and not the id_token (preventing stale access_tokens in sessions). * Refactoring the to use a KeySetStub in the oidc_test.go. This allows the elimination of the slightly contrived function passing elements used previously. (This change is being applied to address the bug #318) * Changes as per the PR comments and preparing for 5.x release * Fixup changelog Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
This commit is contained in:
		
							parent
							
								
									18d20364a8
								
							
						
					
					
						commit
						10adb5c516
					
				|  | @ -3,11 +3,13 @@ | ||||||
| ## Release Hightlights | ## Release Hightlights | ||||||
| 
 | 
 | ||||||
| ## Important Notes | ## Important Notes | ||||||
|  | - [#335] The session expiry for the OIDC provider is now taken from the Token Response (expires_in) rather than from the id_token (exp)  | ||||||
| 
 | 
 | ||||||
| ## Breaking Changes | ## Breaking Changes | ||||||
| 
 | 
 | ||||||
| ## Changes since v5.0.0 | ## Changes since v5.0.0 | ||||||
| 
 | 
 | ||||||
|  | - [#335](https://github.com/pusher/oauth2_proxy/pull/335) OIDC Provider support for empty id_tokens in the access token refresh response (@howzat) | ||||||
| - [#363](https://github.com/pusher/oauth2_proxy/pull/363) Extension of Redis Session Store to Support Redis Cluster (@yan-dblinf) | - [#363](https://github.com/pusher/oauth2_proxy/pull/363) Extension of Redis Session Store to Support Redis Cluster (@yan-dblinf) | ||||||
| - [#353](https://github.com/pusher/oauth2_proxy/pull/353) Fix login page fragment handling after soft reload on Firefox (@ffdybuster) | - [#353](https://github.com/pusher/oauth2_proxy/pull/353) Fix login page fragment handling after soft reload on Firefox (@ffdybuster) | ||||||
| 
 | 
 | ||||||
|  | @ -37,7 +39,6 @@ | ||||||
| - [#179](https://github.com/pusher/oauth2_proxy/pull/179) Add Nextcloud provider (@Ramblurr) | - [#179](https://github.com/pusher/oauth2_proxy/pull/179) Add Nextcloud provider (@Ramblurr) | ||||||
| - [#280](https://github.com/pusher/oauth2_proxy/pull/280) whitelisted redirect domains: add support for whitelisting specific ports or allowing wildcard ports (@kamaln7) | - [#280](https://github.com/pusher/oauth2_proxy/pull/280) whitelisted redirect domains: add support for whitelisting specific ports or allowing wildcard ports (@kamaln7) | ||||||
| - [#351](https://github.com/pusher/oauth2_proxy/pull/351) Add DigitalOcean Auth provider (@kamaln7) | - [#351](https://github.com/pusher/oauth2_proxy/pull/351) Add DigitalOcean Auth provider (@kamaln7) | ||||||
| 
 |  | ||||||
| # v4.1.0 | # v4.1.0 | ||||||
| 
 | 
 | ||||||
| ## Release Highlights | ## Release Highlights | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
|  | @ -42,28 +43,36 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("token exchange: %v", err) | 		return nil, fmt.Errorf("token exchange: %v", err) | ||||||
| 	} | 	} | ||||||
| 	s, err = p.createSessionState(ctx, token) | 
 | ||||||
|  | 	// in the initial exchange the id token is mandatory
 | ||||||
|  | 	idToken, err := p.findVerifiedIDToken(ctx, token) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not verify id_token: %v", err) | ||||||
|  | 	} else if idToken == nil { | ||||||
|  | 		return nil, fmt.Errorf("token response did not contain an id_token") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s, err = p.createSessionState(token, idToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("unable to update session: %v", err) | 		return nil, fmt.Errorf("unable to update session: %v", err) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | // RefreshToken to fetch a new Access Token (and optional ID token) if required
 | ||||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | ||||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	origExpiration := s.ExpiresOn |  | ||||||
| 
 |  | ||||||
| 	err := p.redeemRefreshToken(s) | 	err := p.redeemRefreshToken(s) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) | 	fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn) | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -84,81 +93,76 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to get token: %v", err) | 		return fmt.Errorf("failed to get token: %v", err) | ||||||
| 	} | 	} | ||||||
| 	newSession, err := p.createSessionState(ctx, token) | 
 | ||||||
|  | 	// in the token refresh response the id_token is optional
 | ||||||
|  | 	idToken, err := p.findVerifiedIDToken(ctx, token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("unable to update session: %v", err) | 		return fmt.Errorf("unable to extract id_token from response: %v", err) | ||||||
| 	} | 	} | ||||||
| 	s.AccessToken = newSession.AccessToken | 
 | ||||||
|  | 	newSession, err := p.createSessionState(token, idToken) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("unable create new session state from response: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// It's possible that if the refresh token isn't in the token response the session will not contain an id token
 | ||||||
|  | 	// if it doesn't it's probably better to retain the old one
 | ||||||
|  | 	if newSession.IDToken != "" { | ||||||
| 		s.IDToken = newSession.IDToken | 		s.IDToken = newSession.IDToken | ||||||
|  | 		s.Email = newSession.Email | ||||||
|  | 		s.User = newSession.User | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s.AccessToken = newSession.AccessToken | ||||||
| 	s.RefreshToken = newSession.RefreshToken | 	s.RefreshToken = newSession.RefreshToken | ||||||
| 	s.CreatedAt = newSession.CreatedAt | 	s.CreatedAt = newSession.CreatedAt | ||||||
| 	s.ExpiresOn = newSession.ExpiresOn | 	s.ExpiresOn = newSession.ExpiresOn | ||||||
| 	s.Email = newSession.Email | 
 | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { | func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { | ||||||
| 	rawIDToken, ok := token.Extra("id_token").(string) | 
 | ||||||
| 	if !ok { | 	getIDToken := func() (string, bool) { | ||||||
| 		return nil, fmt.Errorf("token response did not contain an id_token") | 		rawIDToken, _ := token.Extra("id_token").(string) | ||||||
|  | 		return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0 | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Parse and verify ID Token payload.
 | 	if rawIDToken, present := getIDToken(); present { | ||||||
| 	idToken, err := p.Verifier.Verify(ctx, rawIDToken) | 		verifiedIdToken, err := p.Verifier.Verify(ctx, rawIDToken) | ||||||
|  | 		return verifiedIdToken, err | ||||||
|  | 	} else { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OIDCProvider) createSessionState(token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { | ||||||
|  | 
 | ||||||
|  | 	newSession := &sessions.SessionState{} | ||||||
|  | 
 | ||||||
|  | 	if idToken != nil { | ||||||
|  | 		claims, err := findClaimsFromIDToken(idToken, token.AccessToken, p.ProfileURL.String()) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 		return nil, fmt.Errorf("could not verify id_token: %v", err) | 			return nil, fmt.Errorf("couldn't extract claims from id_token (%e)", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 	// Extract custom claims.
 | 		if claims != nil { | ||||||
| 	var claims struct { |  | ||||||
| 		Subject  string `json:"sub"` |  | ||||||
| 		Email    string `json:"email"` |  | ||||||
| 		Verified *bool  `json:"email_verified"` |  | ||||||
| 	} |  | ||||||
| 	if err := idToken.Claims(&claims); err != nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to parse id_token claims: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if claims.Email == "" { |  | ||||||
| 		if p.ProfileURL.String() == "" { |  | ||||||
| 			return nil, fmt.Errorf("id_token did not contain an email") |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
 |  | ||||||
| 		// contents at the profileURL contains the email.
 |  | ||||||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 |  | ||||||
| 
 |  | ||||||
| 		req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		req.Header = getOIDCHeader(token.AccessToken) |  | ||||||
| 
 |  | ||||||
| 		respJSON, err := requests.Request(req) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		email, err := respJSON.Get("email").String() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, fmt.Errorf("Neither id_token nor userinfo endpoint contained an email") |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		claims.Email = email |  | ||||||
| 	} |  | ||||||
| 			if !p.AllowUnverifiedEmail && claims.Verified != nil && !*claims.Verified { | 			if !p.AllowUnverifiedEmail && claims.Verified != nil && !*claims.Verified { | ||||||
| 				return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | 				return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 	return &sessions.SessionState{ | 			newSession.IDToken = token.Extra("id_token").(string) | ||||||
| 		AccessToken:  token.AccessToken, | 			newSession.Email = claims.Email | ||||||
| 		IDToken:      rawIDToken, | 			newSession.User = claims.Subject | ||||||
| 		RefreshToken: token.RefreshToken, | 		} | ||||||
| 		CreatedAt:    time.Now(), | 	} | ||||||
| 		ExpiresOn:    idToken.Expiry, | 
 | ||||||
| 		Email:        claims.Email, | 	newSession.AccessToken = token.AccessToken | ||||||
| 		User:         claims.Subject, | 	newSession.RefreshToken = token.RefreshToken | ||||||
| 	}, nil | 	newSession.CreatedAt = time.Now() | ||||||
|  | 	newSession.ExpiresOn = token.Expiry | ||||||
|  | 	return newSession, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState checks that the session's IDToken is still valid
 | // ValidateSessionState checks that the session's IDToken is still valid
 | ||||||
|  | @ -178,3 +182,48 @@ func getOIDCHeader(accessToken string) http.Header { | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||||
| 	return header | 	return header | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func findClaimsFromIDToken(idToken *oidc.IDToken, accessToken string, profileURL string) (*OIDCClaims, error) { | ||||||
|  | 
 | ||||||
|  | 	// Extract custom claims.
 | ||||||
|  | 	claims := &OIDCClaims{} | ||||||
|  | 	if err := idToken.Claims(claims); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to parse id_token claims: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if claims.Email == "" { | ||||||
|  | 		if profileURL == "" { | ||||||
|  | 			return nil, fmt.Errorf("id_token did not contain an email") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
 | ||||||
|  | 		// contents at the profileURL contains the email.
 | ||||||
|  | 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | ||||||
|  | 
 | ||||||
|  | 		req, err := http.NewRequest("GET", profileURL, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		req.Header = getOIDCHeader(accessToken) | ||||||
|  | 
 | ||||||
|  | 		respJSON, err := requests.Request(req) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		email, err := respJSON.Get("email").String() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained an email") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		claims.Email = email | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return claims, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type OIDCClaims struct { | ||||||
|  | 	Subject  string `json:"sub"` | ||||||
|  | 	Email    string `json:"email"` | ||||||
|  | 	Verified *bool  `json:"email_verified"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,264 @@ | ||||||
|  | package providers | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"crypto/rsa" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
|  | 
 | ||||||
|  | 	"github.com/bmizerany/assert" | ||||||
|  | 	"github.com/coreos/go-oidc" | ||||||
|  | 	"github.com/dgrijalva/jwt-go" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 
 | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const accessToken = "access_token" | ||||||
|  | const refreshToken = "refresh_token" | ||||||
|  | const clientID = "https://test.myapp.com" | ||||||
|  | const secret = "secret" | ||||||
|  | 
 | ||||||
|  | type idTokenClaims struct { | ||||||
|  | 	Name    string `json:"name,omitempty"` | ||||||
|  | 	Email   string `json:"email,omitempty"` | ||||||
|  | 	Picture string `json:"picture,omitempty"` | ||||||
|  | 	jwt.StandardClaims | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type redeemTokenResponse struct { | ||||||
|  | 	AccessToken  string `json:"access_token"` | ||||||
|  | 	RefreshToken string `json:"refresh_token"` | ||||||
|  | 	ExpiresIn    int64  `json:"expires_in"` | ||||||
|  | 	TokenType    string `json:"token_type"` | ||||||
|  | 	IDToken      string `json:"id_token,omitempty"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var defaultIDToken idTokenClaims = idTokenClaims{ | ||||||
|  | 	"Jane Dobbs", | ||||||
|  | 	"janed@me.com", | ||||||
|  | 	"http://mugbook.com/janed/me.jpg", | ||||||
|  | 	jwt.StandardClaims{ | ||||||
|  | 		Audience:  "https://test.myapp.com", | ||||||
|  | 		ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), | ||||||
|  | 		Id:        "id-some-id", | ||||||
|  | 		IssuedAt:  time.Now().Unix(), | ||||||
|  | 		Issuer:    "https://issuer.example.com", | ||||||
|  | 		NotBefore: 0, | ||||||
|  | 		Subject:   "123456789", | ||||||
|  | 	}, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type fakeKeySetStub struct {} | ||||||
|  | 
 | ||||||
|  | func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { | ||||||
|  | 	decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1]) | ||||||
|  | 	tokenClaims := &idTokenClaims{} | ||||||
|  | 	err = json.Unmarshal(decodeString, tokenClaims) | ||||||
|  | 
 | ||||||
|  | 	if err != nil || tokenClaims.Id == "this-id-fails-validation" { | ||||||
|  | 		return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return decodeString, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newOIDCProvider(serverURL *url.URL) *OIDCProvider { | ||||||
|  | 
 | ||||||
|  | 	providerData := &ProviderData{ | ||||||
|  | 		ProviderName: "oidc", | ||||||
|  | 		ClientID:     clientID, | ||||||
|  | 		ClientSecret: secret, | ||||||
|  | 		LoginURL: &url.URL{ | ||||||
|  | 			Scheme: serverURL.Scheme, | ||||||
|  | 			Host:   serverURL.Host, | ||||||
|  | 			Path:   "/login/oauth/authorize"}, | ||||||
|  | 		RedeemURL: &url.URL{ | ||||||
|  | 			Scheme: serverURL.Scheme, | ||||||
|  | 			Host:   serverURL.Host, | ||||||
|  | 			Path:   "/login/oauth/access_token"}, | ||||||
|  | 		ProfileURL: &url.URL{ | ||||||
|  | 			Scheme: serverURL.Scheme, | ||||||
|  | 			Host:   serverURL.Host, | ||||||
|  | 			Path:   "/profile"}, | ||||||
|  | 		ValidateURL: &url.URL{ | ||||||
|  | 			Scheme: serverURL.Scheme, | ||||||
|  | 			Host:   serverURL.Host, | ||||||
|  | 			Path:   "/api"}, | ||||||
|  | 		Scope: "openid profile offline_access"} | ||||||
|  | 
 | ||||||
|  | 	p := &OIDCProvider{ | ||||||
|  | 		ProviderData: providerData, | ||||||
|  | 		Verifier:     oidc.NewVerifier( | ||||||
|  | 			"https://issuer.example.com", | ||||||
|  | 			fakeKeySetStub{}, | ||||||
|  | 			&oidc.Config{ClientID: clientID}, | ||||||
|  | 		), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return p | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newOIDCServer(body []byte) (*url.URL, *httptest.Server) { | ||||||
|  | 	s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | ||||||
|  | 		rw.Header().Add("content-type", "application/json") | ||||||
|  | 		_, _ = rw.Write(body) | ||||||
|  | 	})) | ||||||
|  | 	u, _ := url.Parse(s.URL) | ||||||
|  | 	return u, s | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) { | ||||||
|  | 
 | ||||||
|  | 	key, _ := rsa.GenerateKey(rand.Reader, 2048) | ||||||
|  | 	standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) | ||||||
|  | 	return standardClaims.SignedString(key) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newOauth2Token() *oauth2.Token { | ||||||
|  | 	return &oauth2.Token{ | ||||||
|  | 		AccessToken:  accessToken, | ||||||
|  | 		TokenType:    "Bearer", | ||||||
|  | 		RefreshToken: refreshToken, | ||||||
|  | 		Expiry:       time.Time{}.Add(time.Duration(5) * time.Second), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) { | ||||||
|  | 	redeemURL, server := newOIDCServer(body) | ||||||
|  | 	provider := newOIDCProvider(redeemURL) | ||||||
|  | 	return server, provider | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestOIDCProviderRedeem(t *testing.T) { | ||||||
|  | 
 | ||||||
|  | 	idToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
|  | 	body, _ := json.Marshal(redeemTokenResponse{ | ||||||
|  | 		AccessToken:  accessToken, | ||||||
|  | 		ExpiresIn:    10, | ||||||
|  | 		TokenType:    "Bearer", | ||||||
|  | 		RefreshToken: refreshToken, | ||||||
|  | 		IDToken:      idToken, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	server, provider := newTestSetup(body) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	session, err := provider.Redeem(provider.RedeemURL.String(), "code1234") | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, defaultIDToken.Email, session.Email) | ||||||
|  | 	assert.Equal(t, accessToken, session.AccessToken) | ||||||
|  | 	assert.Equal(t, idToken, session.IDToken) | ||||||
|  | 	assert.Equal(t, refreshToken, session.RefreshToken) | ||||||
|  | 	assert.Equal(t, "123456789", session.User) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | ||||||
|  | 
 | ||||||
|  | 	idToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
|  | 	body, _ := json.Marshal(redeemTokenResponse{ | ||||||
|  | 		AccessToken:  accessToken, | ||||||
|  | 		ExpiresIn:    10, | ||||||
|  | 		TokenType:    "Bearer", | ||||||
|  | 		RefreshToken: refreshToken, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	server, provider := newTestSetup(body) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	existingSession := &sessions.SessionState{ | ||||||
|  | 		AccessToken:  "changeit", | ||||||
|  | 		IDToken:      idToken, | ||||||
|  | 		CreatedAt:    time.Time{}, | ||||||
|  | 		ExpiresOn:    time.Time{}, | ||||||
|  | 		RefreshToken: refreshToken, | ||||||
|  | 		Email:        "janedoe@example.com", | ||||||
|  | 		User:         "11223344", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	refreshed, err := provider.RefreshSessionIfNeeded(existingSession) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, refreshed, true) | ||||||
|  | 	assert.Equal(t, "janedoe@example.com", existingSession.Email) | ||||||
|  | 	assert.Equal(t, accessToken, existingSession.AccessToken) | ||||||
|  | 	assert.Equal(t, idToken, existingSession.IDToken) | ||||||
|  | 	assert.Equal(t, refreshToken, existingSession.RefreshToken) | ||||||
|  | 	assert.Equal(t, "11223344", existingSession.User) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { | ||||||
|  | 
 | ||||||
|  | 	idToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
|  | 	body, _ := json.Marshal(redeemTokenResponse{ | ||||||
|  | 		AccessToken:  accessToken, | ||||||
|  | 		ExpiresIn:    10, | ||||||
|  | 		TokenType:    "Bearer", | ||||||
|  | 		RefreshToken: refreshToken, | ||||||
|  | 		IDToken:      idToken, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	server, provider := newTestSetup(body) | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	existingSession := &sessions.SessionState{ | ||||||
|  | 		AccessToken:  "changeit", | ||||||
|  | 		IDToken:      "changeit", | ||||||
|  | 		CreatedAt:    time.Time{}, | ||||||
|  | 		ExpiresOn:    time.Time{}, | ||||||
|  | 		RefreshToken: refreshToken, | ||||||
|  | 		Email:        "changeit", | ||||||
|  | 		User:         "changeit", | ||||||
|  | 	} | ||||||
|  | 	refreshed, err := provider.RefreshSessionIfNeeded(existingSession) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, refreshed, true) | ||||||
|  | 	assert.Equal(t, defaultIDToken.Email, existingSession.Email) | ||||||
|  | 	assert.Equal(t, defaultIDToken.Subject, existingSession.User) | ||||||
|  | 	assert.Equal(t, accessToken, existingSession.AccessToken) | ||||||
|  | 	assert.Equal(t, idToken, existingSession.IDToken) | ||||||
|  | 	assert.Equal(t, refreshToken, existingSession.RefreshToken) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { | ||||||
|  | 
 | ||||||
|  | 	server, provider := newTestSetup([]byte("")) | ||||||
|  | 
 | ||||||
|  | 	defer server.Close() | ||||||
|  | 
 | ||||||
|  | 	token := newOauth2Token() | ||||||
|  | 	signedIdToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
|  | 	tokenWithIdToken := token.WithExtra(map[string]interface{}{ | ||||||
|  | 		"id_token": signedIdToken, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	verifiedIdToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIdToken) | ||||||
|  | 	assert.Equal(t, true, err == nil) | ||||||
|  | 	assert.Equal(t, true, verifiedIdToken != nil) | ||||||
|  | 	assert.Equal(t, defaultIDToken.Issuer, verifiedIdToken.Issuer) | ||||||
|  | 	assert.Equal(t, defaultIDToken.Subject, verifiedIdToken.Subject) | ||||||
|  | 
 | ||||||
|  | 	// When the validation fails the response should be nil
 | ||||||
|  | 	defaultIDToken.Id = "this-id-fails-validation" | ||||||
|  | 	signedIdToken, _ = newSignedTestIDToken(defaultIDToken) | ||||||
|  | 	tokenWithIdToken = token.WithExtra(map[string]interface{}{ | ||||||
|  | 		"id_token": signedIdToken, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	verifiedIdToken, err = provider.findVerifiedIDToken(context.Background(), tokenWithIdToken) | ||||||
|  | 	assert.Equal(t, errors.New("failed to verify signature: the validation failed for subject [123456789]"), err) | ||||||
|  | 	assert.Equal(t, true, verifiedIdToken == nil) | ||||||
|  | 
 | ||||||
|  | 	// When there is no id token in the oauth token
 | ||||||
|  | 	verifiedIdToken, err = provider.findVerifiedIDToken(context.Background(), newOauth2Token()) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, true, verifiedIdToken == nil) | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue