From 74918c40d8ba0975f84e101f5d39dc7378acdc51 Mon Sep 17 00:00:00 2001 From: Alexander Block Date: Tue, 15 Sep 2020 10:20:10 +0200 Subject: [PATCH] Refactor makeLoginURL to accept extraParams And don't require the caller to know how to use the returned params. --- providers/azure.go | 6 +++--- providers/logingov.go | 8 ++++---- providers/provider_default.go | 4 ++-- providers/util.go | 10 ++++++++-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/providers/azure.go b/providers/azure.go index 9103d178..934f4511 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -212,10 +212,10 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session } func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { - a, params := makeLoginURL(p.ProviderData, redirectURI, state) + extraParams := url.Values{} if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { - params.Add("resource", p.ProtectedResource.String()) + extraParams.Add("resource", p.ProtectedResource.String()) } - a.RawQuery = params.Encode() + a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) return a.String() } diff --git a/providers/logingov.go b/providers/logingov.go index 32fe1c78..a822108f 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -225,12 +225,12 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) // GetLoginURL overrides GetLoginURL to add login.gov parameters func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string { - a, params := makeLoginURL(p.ProviderData, redirectURI, state) + extraParams := url.Values{} if p.AcrValues == "" { acr := "http://idmanagement.gov/ns/assurance/loa/1" - params.Add("acr_values", acr) + extraParams.Add("acr_values", acr) } - params.Add("nonce", p.Nonce) - a.RawQuery = params.Encode() + extraParams.Add("nonce", p.Nonce) + a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams) return a.String() } diff --git a/providers/provider_default.go b/providers/provider_default.go index 5fc53219..337b284c 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -75,8 +75,8 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s // GetLoginURL with typical oauth parameters func (p *ProviderData) GetLoginURL(redirectURI, state string) string { - a, params := makeLoginURL(p, redirectURI, state) - a.RawQuery = params.Encode() + extraParams := url.Values{} + a := makeLoginURL(p, redirectURI, state, extraParams) return a.String() } diff --git a/providers/util.go b/providers/util.go index 5cbc7fb9..b4b65ac6 100644 --- a/providers/util.go +++ b/providers/util.go @@ -31,7 +31,7 @@ func makeOIDCHeader(accessToken string) http.Header { return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) } -func makeLoginURL(p *ProviderData, redirectURI, state string) (url.URL, url.Values) { +func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Values) url.URL { a := *p.LoginURL params, _ := url.ParseQuery(a.RawQuery) params.Set("redirect_uri", redirectURI) @@ -47,5 +47,11 @@ func makeLoginURL(p *ProviderData, redirectURI, state string) (url.URL, url.Valu params.Set("client_id", p.ClientID) params.Set("response_type", "code") params.Add("state", state) - return a, params + for n, p := range extraParams { + for _, v := range p { + params.Add(n, v) + } + } + a.RawQuery = params.Encode() + return a }