Azure token refresh (#754)
* Implement azure token refresh Based on original PR https://github.com/oauth2-proxy/oauth2-proxy/pull/278 * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> * Set CreatedAt to Now() on token refresh Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
This commit is contained in:
		
							parent
							
								
									65016c8da1
								
							
						
					
					
						commit
						0e119d7c84
					
				|  | @ -23,9 +23,16 @@ | ||||||
|   via the login url. If this option was used in the past, behavior will change with this release as it will |   via the login url. If this option was used in the past, behavior will change with this release as it will | ||||||
|   affect the tokens returned by Azure. In the past, the tokens were always for `https://graph.microsoft.com` (the default) |   affect the tokens returned by Azure. In the past, the tokens were always for `https://graph.microsoft.com` (the default) | ||||||
|   and will now be for the configured resource (if it exists, otherwise it will run into errors) |   and will now be for the configured resource (if it exists, otherwise it will run into errors) | ||||||
|  | - [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) The Azure provider now has token refresh functionality implemented. This means that there won't | ||||||
|  |   be any redirects in the browser anymore when tokens expire, but instead a token refresh is initiated | ||||||
|  |   in the background, which leads to new tokens being returned in the cookies. | ||||||
|  |   - Please note that `--cookie-refresh` must be 0 (the default) or equal to the token lifespan configured in Azure AD to make | ||||||
|  |   Azure token refresh reliable. Setting this value to 0 means that it relies on the provider implementation | ||||||
|  |   to decide if a refresh is required. | ||||||
| 
 | 
 | ||||||
| ## Changes since v6.1.1 | ## Changes since v6.1.1 | ||||||
| 
 | 
 | ||||||
|  | - [#754](https://github.com/oauth2-proxy/oauth2-proxy/pull/754) Azure token refresh (@codablock) | ||||||
| - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) | - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) | ||||||
| - [#796](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves) | - [#796](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves) | ||||||
| - [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed) | - [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed) | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -74,6 +75,9 @@ func NewAzureProvider(p *ProviderData) *AzureProvider { | ||||||
| 	if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { | 	if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { | ||||||
| 		p.ProtectedResource = azureDefaultProtectResourceURL | 		p.ProtectedResource = azureDefaultProtectResourceURL | ||||||
| 	} | 	} | ||||||
|  | 	if p.ValidateURL == nil || p.ValidateURL.String() == "" { | ||||||
|  | 		p.ValidateURL = p.ProfileURL | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	return &AzureProvider{ | 	return &AzureProvider{ | ||||||
| 		ProviderData: p, | 		ProviderData: p, | ||||||
|  | @ -103,6 +107,7 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		err = errors.New("missing code") | ||||||
|  | @ -123,6 +128,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		params.Add("resource", p.ProtectedResource.String()) | 		params.Add("resource", p.ProtectedResource.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// blindly try json and x-www-form-urlencoded
 | ||||||
| 	var jsonResponse struct { | 	var jsonResponse struct { | ||||||
| 		AccessToken  string `json:"access_token"` | 		AccessToken  string `json:"access_token"` | ||||||
| 		RefreshToken string `json:"refresh_token"` | 		RefreshToken string `json:"refresh_token"` | ||||||
|  | @ -151,6 +157,61 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		RefreshToken: jsonResponse.RefreshToken, | 		RefreshToken: jsonResponse.RefreshToken, | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
|  | // RefreshToken to fetch a new ID token if required
 | ||||||
|  | func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||||
|  | 		return false, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	origExpiration := s.ExpiresOn | ||||||
|  | 
 | ||||||
|  | 	err := p.redeemRefreshToken(ctx, s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) | ||||||
|  | 	return true, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { | ||||||
|  | 	params := url.Values{} | ||||||
|  | 	params.Add("client_id", p.ClientID) | ||||||
|  | 	params.Add("client_secret", p.ClientSecret) | ||||||
|  | 	params.Add("refresh_token", s.RefreshToken) | ||||||
|  | 	params.Add("grant_type", "refresh_token") | ||||||
|  | 
 | ||||||
|  | 	var jsonResponse struct { | ||||||
|  | 		AccessToken  string `json:"access_token"` | ||||||
|  | 		RefreshToken string `json:"refresh_token"` | ||||||
|  | 		ExpiresOn    int64  `json:"expires_on,string"` | ||||||
|  | 		IDToken      string `json:"id_token"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = requests.New(p.RedeemURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithMethod("POST"). | ||||||
|  | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
|  | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&jsonResponse) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	now := time.Now() | ||||||
|  | 	expires := time.Unix(jsonResponse.ExpiresOn, 0) | ||||||
|  | 	s.AccessToken = jsonResponse.AccessToken | ||||||
|  | 	s.IDToken = jsonResponse.IDToken | ||||||
|  | 	s.RefreshToken = jsonResponse.RefreshToken | ||||||
|  | 	s.CreatedAt = &now | ||||||
|  | 	s.ExpiresOn = &expires | ||||||
|  | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func makeAzureHeader(accessToken string) http.Header { | func makeAzureHeader(accessToken string) http.Header { | ||||||
|  | @ -219,3 +280,8 @@ func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { | ||||||
| 	a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) | 	a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) | ||||||
| 	return a.String() | 	return a.String() | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // ValidateSessionState validates the AccessToken
 | ||||||
|  | func (p *AzureProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||||
|  | 	return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -8,6 +8,8 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | 
 | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
|  | @ -42,7 +44,7 @@ func TestNewAzureProvider(t *testing.T) { | ||||||
| 	g.Expect(providerData.LoginURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/authorize")) | 	g.Expect(providerData.LoginURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/authorize")) | ||||||
| 	g.Expect(providerData.RedeemURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/token")) | 	g.Expect(providerData.RedeemURL.String()).To(Equal("https://login.microsoftonline.com/common/oauth2/token")) | ||||||
| 	g.Expect(providerData.ProfileURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me")) | 	g.Expect(providerData.ProfileURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me")) | ||||||
| 	g.Expect(providerData.ValidateURL.String()).To(Equal("")) | 	g.Expect(providerData.ValidateURL.String()).To(Equal("https://graph.microsoft.com/v1.0/me")) | ||||||
| 	g.Expect(providerData.Scope).To(Equal("openid")) | 	g.Expect(providerData.Scope).To(Equal("openid")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -97,7 +99,7 @@ func TestAzureSetTenant(t *testing.T) { | ||||||
| 		p.Data().ProfileURL.String()) | 		p.Data().ProfileURL.String()) | ||||||
| 	assert.Equal(t, "https://graph.microsoft.com", | 	assert.Equal(t, "https://graph.microsoft.com", | ||||||
| 		p.Data().ProtectedResource.String()) | 		p.Data().ProtectedResource.String()) | ||||||
| 	assert.Equal(t, "", p.Data().ValidateURL.String()) | 	assert.Equal(t, "https://graph.microsoft.com/v1.0/me", p.Data().ValidateURL.String()) | ||||||
| 	assert.Equal(t, "openid", p.Data().Scope) | 	assert.Equal(t, "openid", p.Data().Scope) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -220,3 +222,47 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { | ||||||
| 	result := p.GetLoginURL("https://my.test.app/oauth", "") | 	result := p.GetLoginURL("https://my.test.app/oauth", "") | ||||||
| 	assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) | 	assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestAzureProviderGetsTokensInRedeem(t *testing.T) { | ||||||
|  | 	b := testAzureBackend(`{ "access_token": "some_access_token", "refresh_token": "some_refresh_token", "expires_on": "1136239445", "id_token": "some_id_token" }`) | ||||||
|  | 	defer b.Close() | ||||||
|  | 	timestamp, _ := time.Parse(time.RFC3339, "2006-01-02T22:04:05Z") | ||||||
|  | 	bURL, _ := url.Parse(b.URL) | ||||||
|  | 	p := testAzureProvider(bURL.Host) | ||||||
|  | 
 | ||||||
|  | 	session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.NotEqual(t, session, nil) | ||||||
|  | 	assert.Equal(t, "some_access_token", session.AccessToken) | ||||||
|  | 	assert.Equal(t, "some_refresh_token", session.RefreshToken) | ||||||
|  | 	assert.Equal(t, "some_id_token", session.IDToken) | ||||||
|  | 	assert.Equal(t, timestamp, session.ExpiresOn.UTC()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { | ||||||
|  | 	p := testAzureProvider("") | ||||||
|  | 
 | ||||||
|  | 	expires := time.Now().Add(time.Duration(1) * time.Hour) | ||||||
|  | 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | ||||||
|  | 	refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.False(t, refreshNeeded) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestAzureProviderRefreshWhenExpired(t *testing.T) { | ||||||
|  | 	b := testAzureBackend(`{ "access_token": "new_some_access_token", "refresh_token": "new_some_refresh_token", "expires_on": "32693148245", "id_token": "new_some_id_token" }`) | ||||||
|  | 	defer b.Close() | ||||||
|  | 	timestamp, _ := time.Parse(time.RFC3339, "3006-01-02T22:04:05Z") | ||||||
|  | 	bURL, _ := url.Parse(b.URL) | ||||||
|  | 	p := testAzureProvider(bURL.Host) | ||||||
|  | 
 | ||||||
|  | 	expires := time.Now().Add(time.Duration(-1) * time.Hour) | ||||||
|  | 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | ||||||
|  | 	_, err := p.RefreshSessionIfNeeded(context.Background(), session) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.NotEqual(t, session, nil) | ||||||
|  | 	assert.Equal(t, "new_some_access_token", session.AccessToken) | ||||||
|  | 	assert.Equal(t, "new_some_refresh_token", session.RefreshToken) | ||||||
|  | 	assert.Equal(t, "new_some_id_token", session.IDToken) | ||||||
|  | 	assert.Equal(t, timestamp, session.ExpiresOn.UTC()) | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue