From de9e65a63adf43778a5478c9e331a8fcd71b7f56 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Mon, 6 Jul 2020 17:42:26 +0100 Subject: [PATCH] Migrate all requests to result pattern --- pkg/validation/options.go | 5 ++--- providers/azure.go | 2 ++ providers/bitbucket.go | 3 +++ providers/digitalocean.go | 1 + providers/facebook.go | 1 + providers/github.go | 34 ++++++++++++++++------------------ providers/gitlab.go | 1 + providers/google.go | 2 ++ providers/internal_util.go | 15 ++++++--------- providers/keycloak.go | 1 + providers/linkedin.go | 1 + providers/logingov.go | 2 ++ providers/nextcloud.go | 1 + providers/oidc.go | 1 + providers/provider_default.go | 26 ++++++-------------------- 15 files changed, 46 insertions(+), 50 deletions(-) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index ae2ed065..301c5e90 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -86,6 +86,7 @@ func Validate(o *options.Options) error { requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" body, err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalJSON() if err != nil { logger.Printf("error: failed to discover OIDC configuration: %v", err) @@ -384,11 +385,9 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error if err != nil { // Try as JWKS URI jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" - resp, err := requests.New(jwksURI).Do() - if err != nil { + if err := requests.New(jwksURI).Do().Error(); err != nil { return nil, err } - resp.Body.Close() verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) } else { diff --git a/providers/azure.go b/providers/azure.go index 5e8df0a0..b38c1cc7 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -101,6 +101,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&jsonResponse) if err != nil { return nil, err @@ -153,6 +154,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session json, err := requests.New(p.ProfileURL.String()). WithContext(ctx). WithHeaders(getAzureHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", err diff --git a/providers/bitbucket.go b/providers/bitbucket.go index 726d3562..ffc52c79 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -88,6 +88,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalInto(&emails) if err != nil { logger.Printf("failed making request: %v", err) @@ -103,6 +104,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalInto(&teams) if err != nil { logger.Printf("failed requesting teams membership: %v", err) @@ -132,6 +134,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses err := requests.New(requestURL). WithContext(ctx). + Do(). UnmarshalInto(&repositories) if err != nil { logger.Printf("failed checking repository access: %v", err) diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 306646d5..27ac60d0 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -64,6 +64,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. json, err := requests.New(p.ProfileURL.String()). WithContext(ctx). WithHeaders(getDigitalOceanHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", err diff --git a/providers/facebook.go b/providers/facebook.go index 81973416..d3d123f2 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -72,6 +72,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess err := requests.New(requestURL). WithContext(ctx). WithHeaders(getFacebookHeader(s.AccessToken)). + Do(). UnmarshalInto(&r) if err != nil { return "", err diff --git a/providers/github.go b/providers/github.go index e93a23e1..014ae3cb 100644 --- a/providers/github.go +++ b/providers/github.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io/ioutil" "net/http" "net/url" "path" @@ -116,6 +115,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). + Do(). UnmarshalInto(&op) if err != nil { return false, err @@ -179,12 +179,12 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) // bodyclose cannot detect that the body is being closed later in requests.Into, // so have to skip the linting for the next line. // nolint:bodyclose - resp, err := requests.New(endpoint.String()). + result := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). Do() - if err != nil { - return false, err + if result.Error() != nil { + return false, result.Error() } if last == 0 { @@ -200,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) // link header at last page (doesn't exist last info) // ; rel="prev", ; rel="first" - link := resp.Header.Get("Link") + link := result.Headers().Get("Link") rep1 := regexp.MustCompile(`(?s).*\; rel="last".*`) i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) @@ -211,7 +211,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) } var tp teamsPage - if err := requests.UnmarshalInto(resp, &tp); err != nil { + if err := result.UnmarshalInto(&tp); err != nil { return false, err } if len(tp) == 0 { @@ -282,6 +282,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). + Do(). UnmarshalInto(&repo) if err != nil { return false, err @@ -309,6 +310,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). + Do(). UnmarshalInto(&user) if err != nil { return false, err @@ -328,26 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), } - resp, err := requests.New(endpoint.String()). + result := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(accessToken)). Do() - if err != nil { - return false, err + if result.Error() != nil { + return false, result.Error() } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - - if resp.StatusCode != 204 { + if result.StatusCode() != 204 { return false, fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) + result.StatusCode(), endpoint.String(), result.Body()) } - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) + logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) return true, nil } @@ -401,6 +397,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(s.AccessToken)). + Do(). UnmarshalInto(&emails) if err != nil { return "", err @@ -435,6 +432,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta err := requests.New(endpoint.String()). WithContext(ctx). WithHeaders(getGitHubHeader(s.AccessToken)). + Do(). UnmarshalInto(&user) if err != nil { return "", err diff --git a/providers/gitlab.go b/providers/gitlab.go index 17c06f5e..8c1e1534 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -133,6 +133,7 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta err := requests.New(userInfoURL.String()). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). + Do(). UnmarshalInto(&userInfo) if err != nil { return nil, fmt.Errorf("error getting user info: %v", err) diff --git a/providers/google.go b/providers/google.go index fbed12f1..af2eebf3 100644 --- a/providers/google.go +++ b/providers/google.go @@ -129,6 +129,7 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&jsonResponse) if err != nil { return nil, err @@ -280,6 +281,7 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&data) if err != nil { return "", "", 0, err diff --git a/providers/internal_util.go b/providers/internal_util.go index 361e27c7..42948408 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -2,7 +2,6 @@ package providers import ( "context" - "io/ioutil" "net/http" "net/url" @@ -57,23 +56,21 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h endpoint = endpoint + "?" + params.Encode() } - resp, err := requests.New(endpoint). + result := requests.New(endpoint). WithContext(ctx). WithHeaders(header). Do() - if err != nil { + if result.Error() != nil { logger.Printf("GET %s", stripToken(endpoint)) - logger.Printf("token validation request failed: %s", err) + logger.Printf("token validation request failed: %s", result.Error()) return false } - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) + logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) - if resp.StatusCode == 200 { + if result.StatusCode() == 200 { return true } - logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) + logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) return false } diff --git a/providers/keycloak.go b/providers/keycloak.go index 78206de8..77efc0c7 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -53,6 +53,7 @@ func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess json, err := requests.New(p.ValidateURL.String()). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). + Do(). UnmarshalJSON() if err != nil { logger.Printf("failed making request %s", err) diff --git a/providers/linkedin.go b/providers/linkedin.go index 8dcd3c9d..7328dbbb 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -63,6 +63,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess json, err := requests.New(requestURL). WithContext(ctx). WithHeaders(getLinkedInHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", err diff --git a/providers/logingov.go b/providers/logingov.go index 8846a8f2..79eb1827 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -141,6 +141,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint err := requests.New(userInfoEndpoint). WithContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). + Do(). UnmarshalInto(&emailData) if err != nil { return "", err @@ -196,6 +197,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do(). UnmarshalInto(&jsonResponse) if err != nil { return nil, err diff --git a/providers/nextcloud.go b/providers/nextcloud.go index 844bbff3..b70fd07c 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -33,6 +33,7 @@ func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses json, err := requests.New(p.ValidateURL.String()). WithContext(ctx). WithHeaders(getNextcloudHeader(s.AccessToken)). + Do(). UnmarshalJSON() if err != nil { return "", fmt.Errorf("error making request: %v", err) diff --git a/providers/oidc.go b/providers/oidc.go index 08e0c3e8..e456db76 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -259,6 +259,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. respJSON, err := requests.New(profileURL). WithContext(ctx). WithHeaders(getOIDCHeader(accessToken)). + Do(). UnmarshalJSON() if err != nil { return nil, err diff --git a/providers/provider_default.go b/providers/provider_default.go index f29e8c4a..598b91e8 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -3,10 +3,8 @@ package providers import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io/ioutil" "net/url" "time" @@ -39,33 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - resp, err := requests.New(p.RedeemURL.String()). + result := requests.New(p.RedeemURL.String()). WithContext(ctx). WithMethod("POST"). WithBody(bytes.NewBufferString(params.Encode())). SetHeader("Content-Type", "application/x-www-form-urlencoded"). Do() - if err != nil { - return nil, err - } - - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return + if result.Error() != nil { + return nil, result.Error() } // blindly try json and x-www-form-urlencoded var jsonResponse struct { AccessToken string `json:"access_token"` } - err = json.Unmarshal(body, &jsonResponse) + err = result.UnmarshalInto(&jsonResponse) if err == nil { s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, @@ -74,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s } var v url.Values - v, err = url.ParseQuery(string(body)) + v, err = url.ParseQuery(string(result.Body())) if err != nil { return } @@ -82,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s created := time.Now() s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} } else { - err = fmt.Errorf("no access token found %s", body) + err = fmt.Errorf("no access token found %s", result.Body()) } return }