diff --git a/providers/azure.go b/providers/azure.go index c9940d61..9103d178 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -212,7 +212,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session } func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { - a, params := DefaultGetLoginURL(p.ProviderData, redirectURI, state) + a, params := makeLoginURL(p.ProviderData, redirectURI, state) if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { params.Add("resource", p.ProtectedResource.String()) } diff --git a/providers/logingov.go b/providers/logingov.go index e631237c..32fe1c78 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -225,7 +225,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) // GetLoginURL overrides GetLoginURL to add login.gov parameters func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string { - a, params := DefaultGetLoginURL(p.ProviderData, redirectURI, state) + a, params := makeLoginURL(p.ProviderData, redirectURI, state) if p.AcrValues == "" { acr := "http://idmanagement.gov/ns/assurance/loa/1" params.Add("acr_values", acr) diff --git a/providers/provider_default.go b/providers/provider_default.go index 65c7f729..5fc53219 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -73,28 +73,9 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s return } -func DefaultGetLoginURL(p *ProviderData, redirectURI, state string) (url.URL, url.Values) { - a := *p.LoginURL - params, _ := url.ParseQuery(a.RawQuery) - params.Set("redirect_uri", redirectURI) - if p.AcrValues != "" { - params.Add("acr_values", p.AcrValues) - } - if p.Prompt != "" { - params.Set("prompt", p.Prompt) - } else { // Legacy variant of the prompt param: - params.Set("approval_prompt", p.ApprovalPrompt) - } - params.Add("scope", p.Scope) - params.Set("client_id", p.ClientID) - params.Set("response_type", "code") - params.Add("state", state) - return a, params -} - // GetLoginURL with typical oauth parameters func (p *ProviderData) GetLoginURL(redirectURI, state string) string { - a, params := DefaultGetLoginURL(p, redirectURI, state) + a, params := makeLoginURL(p, redirectURI, state) a.RawQuery = params.Encode() return a.String() } diff --git a/providers/util.go b/providers/util.go index 374f637e..5cbc7fb9 100644 --- a/providers/util.go +++ b/providers/util.go @@ -3,6 +3,7 @@ package providers import ( "fmt" "net/http" + "net/url" ) const ( @@ -29,3 +30,22 @@ func makeOIDCHeader(accessToken string) http.Header { } return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) } + +func makeLoginURL(p *ProviderData, redirectURI, state string) (url.URL, url.Values) { + a := *p.LoginURL + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", redirectURI) + if p.AcrValues != "" { + params.Add("acr_values", p.AcrValues) + } + if p.Prompt != "" { + params.Set("prompt", p.Prompt) + } else { // Legacy variant of the prompt param: + params.Set("approval_prompt", p.ApprovalPrompt) + } + params.Add("scope", p.Scope) + params.Set("client_id", p.ClientID) + params.Set("response_type", "code") + params.Add("state", state) + return a, params +}