Migrate all requests to result pattern
This commit is contained in:
		
							parent
							
								
									d0b6c04960
								
							
						
					
					
						commit
						de9e65a63a
					
				|  | @ -86,6 +86,7 @@ func Validate(o *options.Options) error { | ||||||
| 			requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" | 			requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" | ||||||
| 			body, err := requests.New(requestURL). | 			body, err := requests.New(requestURL). | ||||||
| 				WithContext(ctx). | 				WithContext(ctx). | ||||||
|  | 				Do(). | ||||||
| 				UnmarshalJSON() | 				UnmarshalJSON() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logger.Printf("error: failed to discover OIDC configuration: %v", err) | 				logger.Printf("error: failed to discover OIDC configuration: %v", err) | ||||||
|  | @ -384,11 +385,9 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// Try as JWKS URI
 | 		// Try as JWKS URI
 | ||||||
| 		jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" | 		jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" | ||||||
| 		resp, err := requests.New(jwksURI).Do() | 		if err := requests.New(jwksURI).Do().Error(); err != nil { | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 		resp.Body.Close() |  | ||||||
| 
 | 
 | ||||||
| 		verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) | 		verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) | ||||||
| 	} else { | 	} else { | ||||||
|  |  | ||||||
|  | @ -101,6 +101,7 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		WithMethod("POST"). | 		WithMethod("POST"). | ||||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&jsonResponse) | 		UnmarshalInto(&jsonResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -153,6 +154,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session | ||||||
| 	json, err := requests.New(p.ProfileURL.String()). | 	json, err := requests.New(p.ProfileURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getAzureHeader(s.AccessToken)). | 		WithHeaders(getAzureHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -88,6 +88,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | ||||||
| 	requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken | 	requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken | ||||||
| 	err := requests.New(requestURL). | 	err := requests.New(requestURL). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&emails) | 		UnmarshalInto(&emails) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("failed making request: %v", err) | 		logger.Printf("failed making request: %v", err) | ||||||
|  | @ -103,6 +104,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | ||||||
| 
 | 
 | ||||||
| 		err := requests.New(requestURL). | 		err := requests.New(requestURL). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
|  | 			Do(). | ||||||
| 			UnmarshalInto(&teams) | 			UnmarshalInto(&teams) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("failed requesting teams membership: %v", err) | 			logger.Printf("failed requesting teams membership: %v", err) | ||||||
|  | @ -132,6 +134,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | ||||||
| 
 | 
 | ||||||
| 		err := requests.New(requestURL). | 		err := requests.New(requestURL). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
|  | 			Do(). | ||||||
| 			UnmarshalInto(&repositories) | 			UnmarshalInto(&repositories) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("failed checking repository access: %v", err) | 			logger.Printf("failed checking repository access: %v", err) | ||||||
|  |  | ||||||
|  | @ -64,6 +64,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. | ||||||
| 	json, err := requests.New(p.ProfileURL.String()). | 	json, err := requests.New(p.ProfileURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getDigitalOceanHeader(s.AccessToken)). | 		WithHeaders(getDigitalOceanHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -72,6 +72,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	err := requests.New(requestURL). | 	err := requests.New(requestURL). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getFacebookHeader(s.AccessToken)). | 		WithHeaders(getFacebookHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&r) | 		UnmarshalInto(&r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"path" | 	"path" | ||||||
|  | @ -116,6 +115,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, | ||||||
| 		err := requests.New(endpoint.String()). | 		err := requests.New(endpoint.String()). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
| 			WithHeaders(getGitHubHeader(accessToken)). | 			WithHeaders(getGitHubHeader(accessToken)). | ||||||
|  | 			Do(). | ||||||
| 			UnmarshalInto(&op) | 			UnmarshalInto(&op) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return false, err | 			return false, err | ||||||
|  | @ -179,12 +179,12 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | ||||||
| 		// bodyclose cannot detect that the body is being closed later in requests.Into,
 | 		// bodyclose cannot detect that the body is being closed later in requests.Into,
 | ||||||
| 		// so have to skip the linting for the next line.
 | 		// so have to skip the linting for the next line.
 | ||||||
| 		// nolint:bodyclose
 | 		// nolint:bodyclose
 | ||||||
| 		resp, err := requests.New(endpoint.String()). | 		result := requests.New(endpoint.String()). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
| 			WithHeaders(getGitHubHeader(accessToken)). | 			WithHeaders(getGitHubHeader(accessToken)). | ||||||
| 			Do() | 			Do() | ||||||
| 		if err != nil { | 		if result.Error() != nil { | ||||||
| 			return false, err | 			return false, result.Error() | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if last == 0 { | 		if last == 0 { | ||||||
|  | @ -200,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | ||||||
| 			// link header at last page (doesn't exist last info)
 | 			// link header at last page (doesn't exist last info)
 | ||||||
| 			// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
 | 			// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
 | ||||||
| 
 | 
 | ||||||
| 			link := resp.Header.Get("Link") | 			link := result.Headers().Get("Link") | ||||||
| 			rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`) | 			rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`) | ||||||
| 			i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) | 			i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) | ||||||
| 
 | 
 | ||||||
|  | @ -211,7 +211,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		var tp teamsPage | 		var tp teamsPage | ||||||
| 		if err := requests.UnmarshalInto(resp, &tp); err != nil { | 		if err := result.UnmarshalInto(&tp); err != nil { | ||||||
| 			return false, err | 			return false, err | ||||||
| 		} | 		} | ||||||
| 		if len(tp) == 0 { | 		if len(tp) == 0 { | ||||||
|  | @ -282,6 +282,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(accessToken)). | 		WithHeaders(getGitHubHeader(accessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&repo) | 		UnmarshalInto(&repo) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
|  | @ -309,6 +310,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(accessToken)). | 		WithHeaders(getGitHubHeader(accessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&user) | 		UnmarshalInto(&user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
|  | @ -328,26 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok | ||||||
| 		Host:   p.ValidateURL.Host, | 		Host:   p.ValidateURL.Host, | ||||||
| 		Path:   path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), | 		Path:   path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), | ||||||
| 	} | 	} | ||||||
| 	resp, err := requests.New(endpoint.String()). | 	result := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(accessToken)). | 		WithHeaders(getGitHubHeader(accessToken)). | ||||||
| 		Do() | 		Do() | ||||||
| 	if err != nil { | 	if result.Error() != nil { | ||||||
| 		return false, err | 		return false, result.Error() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) | 	if result.StatusCode() != 204 { | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 204 { |  | ||||||
| 		return false, fmt.Errorf("got %d from %q %s", | 		return false, fmt.Errorf("got %d from %q %s", | ||||||
| 			resp.StatusCode, endpoint.String(), body) | 			result.StatusCode(), endpoint.String(), result.Body()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) | 	logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) | ||||||
| 
 | 
 | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
|  | @ -401,6 +397,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&emails) | 		UnmarshalInto(&emails) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  | @ -435,6 +432,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&user) | 		UnmarshalInto(&user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -133,6 +133,7 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta | ||||||
| 	err := requests.New(userInfoURL.String()). | 	err := requests.New(userInfoURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		SetHeader("Authorization", "Bearer "+s.AccessToken). | 		SetHeader("Authorization", "Bearer "+s.AccessToken). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&userInfo) | 		UnmarshalInto(&userInfo) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("error getting user info: %v", err) | 		return nil, fmt.Errorf("error getting user info: %v", err) | ||||||
|  |  | ||||||
|  | @ -129,6 +129,7 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | ||||||
| 		WithMethod("POST"). | 		WithMethod("POST"). | ||||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&jsonResponse) | 		UnmarshalInto(&jsonResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -280,6 +281,7 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st | ||||||
| 		WithMethod("POST"). | 		WithMethod("POST"). | ||||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&data) | 		UnmarshalInto(&data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", "", 0, err | 		return "", "", 0, err | ||||||
|  |  | ||||||
|  | @ -2,7 +2,6 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
|  | @ -57,23 +56,21 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h | ||||||
| 		endpoint = endpoint + "?" + params.Encode() | 		endpoint = endpoint + "?" + params.Encode() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	resp, err := requests.New(endpoint). | 	result := requests.New(endpoint). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(header). | 		WithHeaders(header). | ||||||
| 		Do() | 		Do() | ||||||
| 	if err != nil { | 	if result.Error() != nil { | ||||||
| 		logger.Printf("GET %s", stripToken(endpoint)) | 		logger.Printf("GET %s", stripToken(endpoint)) | ||||||
| 		logger.Printf("token validation request failed: %s", err) | 		logger.Printf("token validation request failed: %s", result.Error()) | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	body, _ := ioutil.ReadAll(resp.Body) | 	logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) | ||||||
| 	resp.Body.Close() |  | ||||||
| 	logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) |  | ||||||
| 
 | 
 | ||||||
| 	if resp.StatusCode == 200 { | 	if result.StatusCode() == 200 { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) | 	logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -53,6 +53,7 @@ func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	json, err := requests.New(p.ValidateURL.String()). | 	json, err := requests.New(p.ValidateURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		SetHeader("Authorization", "Bearer "+s.AccessToken). | 		SetHeader("Authorization", "Bearer "+s.AccessToken). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("failed making request %s", err) | 		logger.Printf("failed making request %s", err) | ||||||
|  |  | ||||||
|  | @ -63,6 +63,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	json, err := requests.New(requestURL). | 	json, err := requests.New(requestURL). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getLinkedInHeader(s.AccessToken)). | 		WithHeaders(getLinkedInHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -141,6 +141,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint | ||||||
| 	err := requests.New(userInfoEndpoint). | 	err := requests.New(userInfoEndpoint). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		SetHeader("Authorization", "Bearer "+accessToken). | 		SetHeader("Authorization", "Bearer "+accessToken). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&emailData) | 		UnmarshalInto(&emailData) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
|  | @ -196,6 +197,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | ||||||
| 		WithMethod("POST"). | 		WithMethod("POST"). | ||||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalInto(&jsonResponse) | 		UnmarshalInto(&jsonResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  |  | ||||||
|  | @ -33,6 +33,7 @@ func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | ||||||
| 	json, err := requests.New(p.ValidateURL.String()). | 	json, err := requests.New(p.ValidateURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getNextcloudHeader(s.AccessToken)). | 		WithHeaders(getNextcloudHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", fmt.Errorf("error making request: %v", err) | 		return "", fmt.Errorf("error making request: %v", err) | ||||||
|  |  | ||||||
|  | @ -259,6 +259,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | ||||||
| 		respJSON, err := requests.New(profileURL). | 		respJSON, err := requests.New(profileURL). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
| 			WithHeaders(getOIDCHeader(accessToken)). | 			WithHeaders(getOIDCHeader(accessToken)). | ||||||
|  | 			Do(). | ||||||
| 			UnmarshalJSON() | 			UnmarshalJSON() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
|  |  | ||||||
|  | @ -3,10 +3,8 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -39,33 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		params.Add("resource", p.ProtectedResource.String()) | 		params.Add("resource", p.ProtectedResource.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	resp, err := requests.New(p.RedeemURL.String()). | 	result := requests.New(p.RedeemURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithMethod("POST"). | 		WithMethod("POST"). | ||||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
| 		Do() | 		Do() | ||||||
| 	if err != nil { | 	if result.Error() != nil { | ||||||
| 		return nil, err | 		return nil, result.Error() | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// blindly try json and x-www-form-urlencoded
 | 	// blindly try json and x-www-form-urlencoded
 | ||||||
| 	var jsonResponse struct { | 	var jsonResponse struct { | ||||||
| 		AccessToken string `json:"access_token"` | 		AccessToken string `json:"access_token"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &jsonResponse) | 	err = result.UnmarshalInto(&jsonResponse) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		s = &sessions.SessionState{ | 		s = &sessions.SessionState{ | ||||||
| 			AccessToken: jsonResponse.AccessToken, | 			AccessToken: jsonResponse.AccessToken, | ||||||
|  | @ -74,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var v url.Values | 	var v url.Values | ||||||
| 	v, err = url.ParseQuery(string(body)) | 	v, err = url.ParseQuery(string(result.Body())) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | @ -82,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		created := time.Now() | 		created := time.Now() | ||||||
| 		s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} | 		s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} | ||||||
| 	} else { | 	} else { | ||||||
| 		err = fmt.Errorf("no access token found %s", body) | 		err = fmt.Errorf("no access token found %s", result.Body()) | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue