From 53142455b6ed997ee10f45e05db63cb52c9f5311 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 3 Jul 2020 19:27:25 +0100 Subject: [PATCH] Migrate all requests to new builder pattern --- pkg/validation/options.go | 57 ++++++------- providers/azure.go | 49 +++-------- providers/bitbucket.go | 51 +++++------- providers/digitalocean.go | 10 +-- providers/facebook.go | 13 +-- providers/github.go | 153 +++++++++------------------------- providers/gitlab.go | 32 ++----- providers/google.go | 68 +++++---------- providers/internal_util.go | 6 +- providers/keycloak.go | 13 +-- providers/linkedin.go | 11 ++- providers/logingov.go | 84 ++++++------------- providers/nextcloud.go | 17 ++-- providers/oidc.go | 11 +-- providers/provider_default.go | 18 ++-- 15 files changed, 194 insertions(+), 399 deletions(-) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 50717b60..ae2ed065 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -83,34 +83,33 @@ func Validate(o *options.Options) error { logger.Printf("Performing OIDC Discovery...") - if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil { - if body, err := requests.Request(req); err == nil { - - // Prefer manually configured URLs. It's a bit unclear - // why you'd be doing discovery and also providing the URLs - // explicitly though... - if o.LoginURL == "" { - o.LoginURL = body.Get("authorization_endpoint").MustString() - } - - if o.RedeemURL == "" { - o.RedeemURL = body.Get("token_endpoint").MustString() - } - - if o.OIDCJwksURL == "" { - o.OIDCJwksURL = body.Get("jwks_uri").MustString() - } - - if o.ProfileURL == "" { - o.ProfileURL = body.Get("userinfo_endpoint").MustString() - } - - o.SkipOIDCDiscovery = true - } else { - logger.Printf("error: failed to discover OIDC configuration: %v", err) - } + requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" + body, err := requests.New(requestURL). + WithContext(ctx). + UnmarshalJSON() + if err != nil { + logger.Printf("error: failed to discover OIDC configuration: %v", err) } else { - logger.Printf("error: failed parsing OIDC discovery URL: %v", err) + // Prefer manually configured URLs. It's a bit unclear + // why you'd be doing discovery and also providing the URLs + // explicitly though... + if o.LoginURL == "" { + o.LoginURL = body.Get("authorization_endpoint").MustString() + } + + if o.RedeemURL == "" { + o.RedeemURL = body.Get("token_endpoint").MustString() + } + + if o.OIDCJwksURL == "" { + o.OIDCJwksURL = body.Get("jwks_uri").MustString() + } + + if o.ProfileURL == "" { + o.ProfileURL = body.Get("userinfo_endpoint").MustString() + } + + o.SkipOIDCDiscovery = true } } @@ -385,10 +384,12 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error if err != nil { // Try as JWKS URI jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" - _, err := http.NewRequest("GET", jwksURI, nil) + resp, err := requests.New(jwksURI).Do() if err != nil { return nil, err } + resp.Body.Close() + verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) } else { verifier = provider.Verifier(config) diff --git a/providers/azure.go b/providers/azure.go index aea1b0e5..5e8df0a0 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -3,10 +3,8 @@ package providers import ( "bytes" "context" - "encoding/json" "errors" "fmt" - "io/ioutil" "net/http" "net/url" "time" @@ -91,39 +89,21 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - var jsonResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresOn int64 `json:"expires_on,string"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &jsonResponse) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } created := time.Now() @@ -169,26 +149,21 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) - if err != nil { - return "", err - } - req.Header = getAzureHeader(s.AccessToken) - - json, err := requests.Request(req) + json, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(getAzureHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { return "", err } email, err = getEmailFromJSON(json) - if err == nil && email != "" { return email, err } email, err = json.Get("userPrincipalName").String() - if err != nil { logger.Printf("failed making request %s", err) return "", err diff --git a/providers/bitbucket.go b/providers/bitbucket.go index 2bb876cb..726d3562 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -2,7 +2,6 @@ package providers import ( "context" - "net/http" "net/url" "strings" @@ -85,15 +84,13 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses FullName string `json:"full_name"` } } - req, err := http.NewRequestWithContext(ctx, "GET", - p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) + + requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken + err := requests.New(requestURL). + WithContext(ctx). + UnmarshalInto(&emails) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &emails) - if err != nil { - logger.Printf("failed making request %s", err) + logger.Printf("failed making request: %v", err) return "", err } @@ -101,15 +98,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses teamURL := &url.URL{} *teamURL = *p.ValidateURL teamURL.Path = "/2.0/teams" - req, err = http.NewRequestWithContext(ctx, "GET", - teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) + + requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken + + err := requests.New(requestURL). + WithContext(ctx). + UnmarshalInto(&teams) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &teams) - if err != nil { - logger.Printf("failed requesting teams membership %s", err) + logger.Printf("failed requesting teams membership: %v", err) return "", err } var found = false @@ -129,20 +125,19 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses repositoriesURL := &url.URL{} *repositoriesURL = *p.ValidateURL repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] - req, err = http.NewRequestWithContext(ctx, "GET", - repositoriesURL.String()+"?role=contributor"+ - "&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+ - "&access_token="+s.AccessToken, - nil) + + requestURL := repositoriesURL.String() + "?role=contributor" + + "&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") + + "&access_token=" + s.AccessToken + + err := requests.New(requestURL). + WithContext(ctx). + UnmarshalInto(&repositories) if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - err = requests.RequestJSON(req, &repositories) - if err != nil { - logger.Printf("failed checking repository access %s", err) + logger.Printf("failed checking repository access: %v", err) return "", err } + var found = false for _, repository := range repositories.Values { if p.Repository == repository.FullName { diff --git a/providers/digitalocean.go b/providers/digitalocean.go index 25d37af9..306646d5 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -60,13 +60,11 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) - if err != nil { - return "", err - } - req.Header = getDigitalOceanHeader(s.AccessToken) - json, err := requests.Request(req) + json, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(getDigitalOceanHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { return "", err } diff --git a/providers/facebook.go b/providers/facebook.go index 0f9cc624..81973416 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -62,20 +62,21 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil) - if err != nil { - return "", err - } - req.Header = getFacebookHeader(s.AccessToken) type result struct { Email string } var r result - err = requests.RequestJSON(req, &r) + + requestURL := p.ProfileURL.String() + "?fields=name,email" + err := requests.New(requestURL). + WithContext(ctx). + WithHeaders(getFacebookHeader(s.AccessToken)). + UnmarshalInto(&r) if err != nil { return "", err } + if r.Email == "" { return "", errors.New("no email") } diff --git a/providers/github.go b/providers/github.go index 6d3f8b02..e93a23e1 100644 --- a/providers/github.go +++ b/providers/github.go @@ -2,7 +2,6 @@ package providers import ( "context" - "encoding/json" "errors" "fmt" "io/ioutil" @@ -15,6 +14,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) // GitHubProvider represents an GitHub based Identity Provider @@ -111,27 +111,16 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, Path: path.Join(p.ValidateURL.Path, "/user/orgs"), RawQuery: params.Encode(), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } var op orgsPage - if err := json.Unmarshal(body, &op); err != nil { + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + UnmarshalInto(&op) + if err != nil { return false, err } + if len(op) == 0 { break } @@ -187,9 +176,13 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) RawQuery: params.Encode(), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) + // bodyclose cannot detect that the body is being closed later in requests.Into, + // so have to skip the linting for the next line. + // nolint:bodyclose + resp, err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do() if err != nil { return false, err } @@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) } } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - resp.Body.Close() - return false, err - } - resp.Body.Close() - - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } - var tp teamsPage - if err := json.Unmarshal(body, &tp); err != nil { - return false, fmt.Errorf("%s unmarshaling %s", err, body) + if err := requests.UnmarshalInto(resp, &tp); err != nil { + return false, err } if len(tp) == 0 { break @@ -297,25 +278,12 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } - var repo repository - if err := json.Unmarshal(body, &repo); err != nil { + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + UnmarshalInto(&repo) + if err != nil { return false, err } @@ -337,26 +305,14 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user"), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + UnmarshalInto(&user) if err != nil { return false, err } - if resp.StatusCode != 200 { - return false, fmt.Errorf("got %d from %q %s", - resp.StatusCode, stripToken(endpoint.String()), body) - } - - if err := json.Unmarshal(body, &user); err != nil { - return false, err - } if p.isVerifiedUser(user.Login) { return true, nil @@ -372,12 +328,14 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(accessToken) - resp, err := http.DefaultClient.Do(req) + resp, err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(accessToken)). + Do() if err != nil { return false, err } + body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { @@ -440,28 +398,13 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user/emails"), } - req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - req.Header = getGitHubHeader(s.AccessToken) - resp, err := http.DefaultClient.Do(req) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(s.AccessToken)). + UnmarshalInto(&emails) if err != nil { return "", err } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return "", err - } - - if resp.StatusCode != 200 { - return "", fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) - } - - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - - if err := json.Unmarshal(body, &emails); err != nil { - return "", fmt.Errorf("%s unmarshaling %s", err, body) - } returnEmail := "" for _, email := range emails { @@ -489,34 +432,14 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta Path: path.Join(p.ValidateURL.Path, "/user"), } - req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) - if err != nil { - return "", fmt.Errorf("could not create new GET request: %v", err) - } - - req.Header = getGitHubHeader(s.AccessToken) - resp, err := http.DefaultClient.Do(req) + err := requests.New(endpoint.String()). + WithContext(ctx). + WithHeaders(getGitHubHeader(s.AccessToken)). + UnmarshalInto(&user) if err != nil { return "", err } - body, err := ioutil.ReadAll(resp.Body) - defer resp.Body.Close() - if err != nil { - return "", err - } - - if resp.StatusCode != 200 { - return "", fmt.Errorf("got %d from %q %s", - resp.StatusCode, endpoint.String(), body) - } - - logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - - if err := json.Unmarshal(body, &user); err != nil { - return "", fmt.Errorf("%s unmarshaling %s", err, body) - } - // Now that we have the username we can check collaborator status if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { diff --git a/providers/gitlab.go b/providers/gitlab.go index 8d959781..17c06f5e 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -2,15 +2,13 @@ package providers import ( "context" - "encoding/json" "fmt" - "io/ioutil" - "net/http" "strings" "time" oidc "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "golang.org/x/oauth2" ) @@ -131,31 +129,13 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta userInfoURL := *p.LoginURL userInfoURL.Path = "/oauth/userinfo" - req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil) - if err != nil { - return nil, fmt.Errorf("failed to create user info request: %v", err) - } - req.Header.Set("Authorization", "Bearer "+s.AccessToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to perform user info request: %v", err) - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("failed to read user info response: %v", err) - } - - if resp.StatusCode != 200 { - return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body) - } - var userInfo gitlabUserInfo - err = json.Unmarshal(body, &userInfo) + err := requests.New(userInfoURL.String()). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+s.AccessToken). + UnmarshalInto(&userInfo) if err != nil { - return nil, fmt.Errorf("failed to parse user info: %v", err) + return nil, fmt.Errorf("error getting user info: %v", err) } return &userInfo, nil diff --git a/providers/google.go b/providers/google.go index 5aeb6e2d..fbed12f1 100644 --- a/providers/google.go +++ b/providers/google.go @@ -9,13 +9,13 @@ import ( "fmt" "io" "io/ioutil" - "net/http" "net/url" "strings" "time" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" "google.golang.org/api/googleapi" @@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( params.Add("client_secret", clientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } var jsonResponse struct { AccessToken string `json:"access_token"` @@ -145,10 +123,17 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &jsonResponse) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } + c, err := claimsFromIDToken(jsonResponse.IDToken) if err != nil { return @@ -283,38 +268,23 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st params.Add("client_secret", clientSecret) params.Add("refresh_token", refreshToken) params.Add("grant_type", "refresh_token") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } var data struct { AccessToken string `json:"access_token"` ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` } - err = json.Unmarshal(body, &data) + + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&data) if err != nil { - return + return "", "", 0, err } + token = data.AccessToken idToken = data.IDToken expires = time.Duration(data.ExpiresIn) * time.Second diff --git a/providers/internal_util.go b/providers/internal_util.go index f9bdc304..361e27c7 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -56,7 +56,11 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h params := url.Values{"access_token": {accessToken}} endpoint = endpoint + "?" + params.Encode() } - resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header) + + resp, err := requests.New(endpoint). + WithContext(ctx). + WithHeaders(header). + Do() if err != nil { logger.Printf("GET %s", stripToken(endpoint)) logger.Printf("token validation request failed: %s", err) diff --git a/providers/keycloak.go b/providers/keycloak.go index 414c4973..78206de8 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -2,7 +2,6 @@ package providers import ( "context" - "net/http" "net/url" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" @@ -51,14 +50,10 @@ func (p *KeycloakProvider) SetGroup(group string) { } func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - - req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) - req.Header.Set("Authorization", "Bearer "+s.AccessToken) - if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - json, err := requests.Request(req) + json, err := requests.New(p.ValidateURL.String()). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+s.AccessToken). + UnmarshalJSON() if err != nil { logger.Printf("failed making request %s", err) return "", err diff --git a/providers/linkedin.go b/providers/linkedin.go index 6cc24239..8dcd3c9d 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -58,13 +58,12 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess if s.AccessToken == "" { return "", errors.New("missing access token") } - req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil) - if err != nil { - return "", err - } - req.Header = getLinkedInHeader(s.AccessToken) - json, err := requests.Request(req) + requestURL := p.ProfileURL.String() + "?format=json" + json, err := requests.New(requestURL). + WithContext(ctx). + WithHeaders(getLinkedInHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { return "", err } diff --git a/providers/logingov.go b/providers/logingov.go index 46027172..8846a8f2 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -15,6 +15,7 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" "gopkg.in/square/go-jose.v2" ) @@ -128,51 +129,33 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) { return } -func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) { - // query the user info endpoint for user attributes - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil) - if err != nil { - return - } - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body) - return - } - +func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) { // parse the user attributes from the data we got and make sure that // the email address has been validated. var emailData struct { Email string `json:"email"` EmailVerified bool `json:"email_verified"` } - err = json.Unmarshal(body, &emailData) + + // query the user info endpoint for user attributes + err := requests.New(userInfoEndpoint). + WithContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + UnmarshalInto(&emailData) if err != nil { - return + return "", err } - if emailData.Email == "" { - err = fmt.Errorf("missing email") - return + + email := emailData.Email + if email == "" { + return "", fmt.Errorf("missing email") } - email = emailData.Email + if !emailData.EmailVerified { - err = fmt.Errorf("email %s not listed as verified", email) - return + return "", fmt.Errorf("email %s not listed as verified", email) } - return + + return email, nil } // Redeem exchanges the OAuth2 authentication token for an ID token @@ -201,30 +184,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) params.Add("code", code) params.Add("grant_type", "authorization_code") - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - var body []byte - body, err = ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return - } - - if resp.StatusCode != 200 { - err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) - return - } - // Get the token from the body that we got from the token endpoint. var jsonResponse struct { AccessToken string `json:"access_token"` @@ -232,9 +191,14 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` } - err = json.Unmarshal(body, &jsonResponse) + err = requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + UnmarshalInto(&jsonResponse) if err != nil { - return + return nil, err } // check nonce here diff --git a/providers/nextcloud.go b/providers/nextcloud.go index d51b7183..844bbff3 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) @@ -31,18 +30,14 @@ func getNextcloudHeader(accessToken string) http.Header { // GetEmailAddress returns the Account email address func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { - req, err := http.NewRequestWithContext(ctx, "GET", - p.ValidateURL.String(), nil) + json, err := requests.New(p.ValidateURL.String()). + WithContext(ctx). + WithHeaders(getNextcloudHeader(s.AccessToken)). + UnmarshalJSON() if err != nil { - logger.Printf("failed building request %s", err) - return "", err - } - req.Header = getNextcloudHeader(s.AccessToken) - json, err := requests.Request(req) - if err != nil { - logger.Printf("failed making request %s", err) - return "", err + return "", fmt.Errorf("error making request: %v", err) } + email, err := json.Get("ocs").Get("data").Get("email").String() return email, err } diff --git a/providers/oidc.go b/providers/oidc.go index bc98dbb8..08e0c3e8 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -256,13 +256,10 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo // contents at the profileURL contains the email. // Make a query to the userinfo endpoint, and attempt to locate the email from there. - req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil) - if err != nil { - return nil, err - } - req.Header = getOIDCHeader(accessToken) - - respJSON, err := requests.Request(req) + respJSON, err := requests.New(profileURL). + WithContext(ctx). + WithHeaders(getOIDCHeader(accessToken)). + UnmarshalJSON() if err != nil { return nil, err } diff --git a/providers/provider_default.go b/providers/provider_default.go index 14cec9fe..f29e8c4a 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -7,13 +7,13 @@ import ( "errors" "fmt" "io/ioutil" - "net/http" "net/url" "time" "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" ) var _ Provider = (*ProviderData)(nil) @@ -39,18 +39,16 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s params.Add("resource", p.ProtectedResource.String()) } - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) - if err != nil { - return - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - var resp *http.Response - resp, err = http.DefaultClient.Do(req) + resp, err := requests.New(p.RedeemURL.String()). + WithContext(ctx). + WithMethod("POST"). + WithBody(bytes.NewBufferString(params.Encode())). + SetHeader("Content-Type", "application/x-www-form-urlencoded"). + Do() if err != nil { return nil, err } + var body []byte body, err = ioutil.ReadAll(resp.Body) resp.Body.Close()