diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e4ec5bf..3e1ac927 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,9 +12,14 @@ ## Breaking Changes - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) When a Redis session store is configured, OAuth2-Proxy will fail to start up unless connection and health checks to Redis pass +- [#753](https://github.com/oauth2-proxy/oauth2-proxy/pull/753) A bug in the Azure provider prevented it from properly passing the configured protected `--resource` + via the login url. If this option was used in the past, behavior will change with this release as it will + affect the tokens returned by Azure. In the past, the tokens were always for `https://graph.microsoft.com` (the default) + and will now be for the configured resource (if it exists, otherwise it will run into errors) ## Changes since v6.1.1 +- [#753](https://github.com/oauth2-proxy/oauth2-proxy/pull/753) Pass resource parameter in login url (@codablock) - [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Stop accepting legacy SHA1 signed cookies (@NickMeves) - [#722](https://github.com/oauth2-proxy/oauth2-proxy/pull/722) Validate Redis configuration options at startup (@NickMeves) - [#791](https://github.com/oauth2-proxy/oauth2-proxy/pull/791) Remove GetPreferredUsername method from provider interface (@NickMeves) diff --git a/providers/azure.go b/providers/azure.go index 0ae0cba6..934f4511 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -210,3 +210,12 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session return email, err } + +func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { + extraParams := url.Values{} + if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { + extraParams.Add("resource", p.ProtectedResource.String()) + } + a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) + return a.String() +} diff --git a/providers/azure_test.go b/providers/azure_test.go index fe9bbb42..6e2e4e97 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -213,3 +213,10 @@ func TestAzureProviderRedeemReturnsIdToken(t *testing.T) { assert.Equal(t, timestamp, s.ExpiresOn.UTC()) assert.Equal(t, "refresh1234", s.RefreshToken) } + +func TestAzureProviderProtectedResourceConfigured(t *testing.T) { + p := testAzureProvider("") + p.ProtectedResource, _ = url.Parse("http://my.resource.test") + result := p.GetLoginURL("https://my.test.app/oauth", "") + assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) +} diff --git a/providers/logingov.go b/providers/logingov.go index c524741f..a822108f 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -225,20 +225,12 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) // GetLoginURL overrides GetLoginURL to add login.gov parameters func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string { - a := *p.LoginURL - params, _ := url.ParseQuery(a.RawQuery) - params.Set("redirect_uri", redirectURI) - params.Set("approval_prompt", p.ApprovalPrompt) - params.Add("scope", p.Scope) - params.Set("client_id", p.ClientID) - params.Set("response_type", "code") - params.Add("state", state) - acr := p.AcrValues - if acr == "" { - acr = "http://idmanagement.gov/ns/assurance/loa/1" + extraParams := url.Values{} + if p.AcrValues == "" { + acr := "http://idmanagement.gov/ns/assurance/loa/1" + extraParams.Add("acr_values", acr) } - params.Add("acr_values", acr) - params.Add("nonce", p.Nonce) - a.RawQuery = params.Encode() + extraParams.Add("nonce", p.Nonce) + a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) return a.String() } diff --git a/providers/logingov_test.go b/providers/logingov_test.go index 2c0f8357..0b70190b 100644 --- a/providers/logingov_test.go +++ b/providers/logingov_test.go @@ -289,3 +289,10 @@ func TestLoginGovProviderBadNonce(t *testing.T) { // The "badfakenonce" in the idtoken above should cause this to error out assert.Error(t, err) } + +func TestLoginGovProviderGetLoginURL(t *testing.T) { + p, _, _ := newLoginGovProvider() + result := p.GetLoginURL("http://redirect/", "") + assert.Contains(t, result, "acr_values="+url.QueryEscape("http://idmanagement.gov/ns/assurance/loa/1")) + assert.Contains(t, result, "nonce=fakenonce") +} diff --git a/providers/provider_default.go b/providers/provider_default.go index ba05a96c..337b284c 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -75,22 +75,8 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s // GetLoginURL with typical oauth parameters func (p *ProviderData) GetLoginURL(redirectURI, state string) string { - a := *p.LoginURL - params, _ := url.ParseQuery(a.RawQuery) - params.Set("redirect_uri", redirectURI) - if p.AcrValues != "" { - params.Add("acr_values", p.AcrValues) - } - if p.Prompt != "" { - params.Set("prompt", p.Prompt) - } else { // Legacy variant of the prompt param: - params.Set("approval_prompt", p.ApprovalPrompt) - } - params.Add("scope", p.Scope) - params.Set("client_id", p.ClientID) - params.Set("response_type", "code") - params.Add("state", state) - a.RawQuery = params.Encode() + extraParams := url.Values{} + a := makeLoginURL(p, redirectURI, state, extraParams) return a.String() } diff --git a/providers/util.go b/providers/util.go index 374f637e..b4b65ac6 100644 --- a/providers/util.go +++ b/providers/util.go @@ -3,6 +3,7 @@ package providers import ( "fmt" "net/http" + "net/url" ) const ( @@ -29,3 +30,28 @@ func makeOIDCHeader(accessToken string) http.Header { } return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) } + +func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Values) url.URL { + a := *p.LoginURL + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", redirectURI) + if p.AcrValues != "" { + params.Add("acr_values", p.AcrValues) + } + if p.Prompt != "" { + params.Set("prompt", p.Prompt) + } else { // Legacy variant of the prompt param: + params.Set("approval_prompt", p.ApprovalPrompt) + } + params.Add("scope", p.Scope) + params.Set("client_id", p.ClientID) + params.Set("response_type", "code") + params.Add("state", state) + for n, p := range extraParams { + for _, v := range p { + params.Add(n, v) + } + } + a.RawQuery = params.Encode() + return a +}