Create generic Authorization Header constructor
This commit is contained in:
		
							parent
							
								
									9a338d8a34
								
							
						
					
					
						commit
						d05e08cba3
					
				|  | @ -11,6 +11,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v6.0.0 | ## Changes since v6.0.0 | ||||||
| 
 | 
 | ||||||
|  | - [#562](https://github.com/oauth2-proxy/oauth2-proxy/pull/562) Create generic Authorization Header constructor (@JoelSpeed) | ||||||
| - [#715](https://github.com/oauth2-proxy/oauth2-proxy/pull/715) Ensure session times are not nil before printing them (@JoelSpeed) | - [#715](https://github.com/oauth2-proxy/oauth2-proxy/pull/715) Ensure session times are not nil before printing them (@JoelSpeed) | ||||||
| - [#714](https://github.com/oauth2-proxy/oauth2-proxy/pull/714) Support passwords with Redis session stores (@NickMeves) | - [#714](https://github.com/oauth2-proxy/oauth2-proxy/pull/714) Support passwords with Redis session stores (@NickMeves) | ||||||
| - [#719](https://github.com/oauth2-proxy/oauth2-proxy/pull/719) Add Gosec fixes to areas that are intermittently flagged on PRs (@NickMeves) | - [#719](https://github.com/oauth2-proxy/oauth2-proxy/pull/719) Add Gosec fixes to areas that are intermittently flagged on PRs (@NickMeves) | ||||||
|  |  | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -154,10 +153,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getAzureHeader(accessToken string) http.Header { | func makeAzureHeader(accessToken string) http.Header { | ||||||
| 	header := make(http.Header) | 	return makeAuthorizationHeader(tokenTypeBearer, accessToken, nil) | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) |  | ||||||
| 	return header |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getEmailFromJSON(json *simplejson.Json) (string, error) { | func getEmailFromJSON(json *simplejson.Json) (string, error) { | ||||||
|  | @ -188,7 +185,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session | ||||||
| 
 | 
 | ||||||
| 	json, err := requests.New(p.ProfileURL.String()). | 	json, err := requests.New(p.ProfileURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getAzureHeader(s.AccessToken)). | 		WithHeaders(makeAzureHeader(s.AccessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -3,8 +3,6 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | @ -62,13 +60,6 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { | ||||||
| 	return &DigitalOceanProvider{ProviderData: p} | 	return &DigitalOceanProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getDigitalOceanHeader(accessToken string) http.Header { |  | ||||||
| 	header := make(http.Header) |  | ||||||
| 	header.Set("Content-Type", "application/json") |  | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) |  | ||||||
| 	return header |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
|  | @ -77,7 +68,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. | ||||||
| 
 | 
 | ||||||
| 	json, err := requests.New(p.ProfileURL.String()). | 	json, err := requests.New(p.ProfileURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getDigitalOceanHeader(s.AccessToken)). | 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -93,5 +84,5 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *DigitalOceanProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | func (p *DigitalOceanProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, getDigitalOceanHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -3,8 +3,6 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | @ -63,14 +61,6 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { | ||||||
| 	return &FacebookProvider{ProviderData: p} | 	return &FacebookProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getFacebookHeader(accessToken string) http.Header { |  | ||||||
| 	header := make(http.Header) |  | ||||||
| 	header.Set("Accept", "application/json") |  | ||||||
| 	header.Set("x-li-format", "json") |  | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) |  | ||||||
| 	return header |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
|  | @ -85,7 +75,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	requestURL := p.ProfileURL.String() + "?fields=name,email" | 	requestURL := p.ProfileURL.String() + "?fields=name,email" | ||||||
| 	err := requests.New(requestURL). | 	err := requests.New(requestURL). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getFacebookHeader(s.AccessToken)). | 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalInto(&r) | 		UnmarshalInto(&r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -100,5 +90,5 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *FacebookProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | func (p *FacebookProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, getFacebookHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -74,11 +74,12 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider { | ||||||
| 	return &GitHubProvider{ProviderData: p} | 	return &GitHubProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getGitHubHeader(accessToken string) http.Header { | func makeGitHubHeader(accessToken string) http.Header { | ||||||
| 	header := make(http.Header) | 	// extra headers required by the GitHub API when making authenticated requests
 | ||||||
| 	header.Set("Accept", "application/vnd.github.v3+json") | 	extraHeaders := map[string]string{ | ||||||
| 	header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) | 		acceptHeader: "application/vnd.github.v3+json", | ||||||
| 	return header | 	} | ||||||
|  | 	return makeAuthorizationHeader(tokenTypeToken, accessToken, extraHeaders) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope
 | // SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope
 | ||||||
|  | @ -129,7 +130,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, | ||||||
| 		var op orgsPage | 		var op orgsPage | ||||||
| 		err := requests.New(endpoint.String()). | 		err := requests.New(endpoint.String()). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
| 			WithHeaders(getGitHubHeader(accessToken)). | 			WithHeaders(makeGitHubHeader(accessToken)). | ||||||
| 			Do(). | 			Do(). | ||||||
| 			UnmarshalInto(&op) | 			UnmarshalInto(&op) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | @ -196,7 +197,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | ||||||
| 		// nolint:bodyclose
 | 		// nolint:bodyclose
 | ||||||
| 		result := requests.New(endpoint.String()). | 		result := requests.New(endpoint.String()). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
| 			WithHeaders(getGitHubHeader(accessToken)). | 			WithHeaders(makeGitHubHeader(accessToken)). | ||||||
| 			Do() | 			Do() | ||||||
| 		if result.Error() != nil { | 		if result.Error() != nil { | ||||||
| 			return false, result.Error() | 			return false, result.Error() | ||||||
|  | @ -296,7 +297,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, | ||||||
| 	var repo repository | 	var repo repository | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(accessToken)). | 		WithHeaders(makeGitHubHeader(accessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalInto(&repo) | 		UnmarshalInto(&repo) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -324,7 +325,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, | ||||||
| 
 | 
 | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(accessToken)). | 		WithHeaders(makeGitHubHeader(accessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalInto(&user) | 		UnmarshalInto(&user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -347,7 +348,7 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok | ||||||
| 	} | 	} | ||||||
| 	result := requests.New(endpoint.String()). | 	result := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(accessToken)). | 		WithHeaders(makeGitHubHeader(accessToken)). | ||||||
| 		Do() | 		Do() | ||||||
| 	if result.Error() != nil { | 	if result.Error() != nil { | ||||||
| 		return false, result.Error() | 		return false, result.Error() | ||||||
|  | @ -411,7 +412,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio | ||||||
| 	} | 	} | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | 		WithHeaders(makeGitHubHeader(s.AccessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalInto(&emails) | 		UnmarshalInto(&emails) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -446,7 +447,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta | ||||||
| 
 | 
 | ||||||
| 	err := requests.New(endpoint.String()). | 	err := requests.New(endpoint.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | 		WithHeaders(makeGitHubHeader(s.AccessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalInto(&user) | 		UnmarshalInto(&user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -465,7 +466,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, getGitHubHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // isVerifiedUser
 | // isVerifiedUser
 | ||||||
|  |  | ||||||
|  | @ -3,7 +3,6 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
|  | @ -62,12 +61,13 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { | ||||||
| 	return &LinkedInProvider{ProviderData: p} | 	return &LinkedInProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getLinkedInHeader(accessToken string) http.Header { | func makeLinkedInHeader(accessToken string) http.Header { | ||||||
| 	header := make(http.Header) | 	// extra headers required by the LinkedIn API when making authenticated requests
 | ||||||
| 	header.Set("Accept", "application/json") | 	extraHeaders := map[string]string{ | ||||||
| 	header.Set("x-li-format", "json") | 		acceptHeader:  acceptApplicationJSON, | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | 		"x-li-format": "json", | ||||||
| 	return header | 	} | ||||||
|  | 	return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
|  | @ -79,7 +79,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	requestURL := p.ProfileURL.String() + "?format=json" | 	requestURL := p.ProfileURL.String() + "?format=json" | ||||||
| 	json, err := requests.New(requestURL). | 	json, err := requests.New(requestURL). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getLinkedInHeader(s.AccessToken)). | 		WithHeaders(makeLinkedInHeader(s.AccessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -95,5 +95,5 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *LinkedInProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | func (p *LinkedInProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -3,7 +3,6 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||||
|  | @ -24,17 +23,11 @@ func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { | ||||||
| 	return &NextcloudProvider{ProviderData: p} | 	return &NextcloudProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getNextcloudHeader(accessToken string) http.Header { |  | ||||||
| 	header := make(http.Header) |  | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) |  | ||||||
| 	return header |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||||
| 	json, err := requests.New(p.ValidateURL.String()). | 	json, err := requests.New(p.ValidateURL.String()). | ||||||
| 		WithContext(ctx). | 		WithContext(ctx). | ||||||
| 		WithHeaders(getNextcloudHeader(s.AccessToken)). | 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalJSON() | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -3,7 +3,6 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -221,13 +220,6 @@ func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.Ses | ||||||
| 	return err == nil | 	return err == nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getOIDCHeader(accessToken string) http.Header { |  | ||||||
| 	header := make(http.Header) |  | ||||||
| 	header.Set("Accept", "application/json") |  | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) |  | ||||||
| 	return header |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { | func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { | ||||||
| 	claims := &OIDCClaims{} | 	claims := &OIDCClaims{} | ||||||
| 	// Extract default claims.
 | 	// Extract default claims.
 | ||||||
|  | @ -263,7 +255,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | ||||||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | ||||||
| 		respJSON, err := requests.New(profileURL). | 		respJSON, err := requests.New(profileURL). | ||||||
| 			WithContext(ctx). | 			WithContext(ctx). | ||||||
| 			WithHeaders(getOIDCHeader(token.AccessToken)). | 			WithHeaders(makeOIDCHeader(token.AccessToken)). | ||||||
| 			Do(). | 			Do(). | ||||||
| 			UnmarshalJSON() | 			UnmarshalJSON() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  |  | ||||||
|  | @ -0,0 +1,31 @@ | ||||||
|  | package providers | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	tokenTypeBearer = "Bearer" | ||||||
|  | 	tokenTypeToken  = "token" | ||||||
|  | 
 | ||||||
|  | 	acceptHeader          = "Accept" | ||||||
|  | 	acceptApplicationJSON = "application/json" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func makeAuthorizationHeader(prefix, token string, extraHeaders map[string]string) http.Header { | ||||||
|  | 	header := make(http.Header) | ||||||
|  | 	for key, value := range extraHeaders { | ||||||
|  | 		header.Add(key, value) | ||||||
|  | 	} | ||||||
|  | 	header.Set("Authorization", fmt.Sprintf("%s %s", prefix, token)) | ||||||
|  | 	return header | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func makeOIDCHeader(accessToken string) http.Header { | ||||||
|  | 	// extra headers required by the IDP when making authenticated requests
 | ||||||
|  | 	extraHeaders := map[string]string{ | ||||||
|  | 		acceptHeader: acceptApplicationJSON, | ||||||
|  | 	} | ||||||
|  | 	return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) | ||||||
|  | } | ||||||
|  | @ -0,0 +1,66 @@ | ||||||
|  | package providers | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestMakeAuhtorizationHeader(t *testing.T) { | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		name         string | ||||||
|  | 		prefix       string | ||||||
|  | 		token        string | ||||||
|  | 		extraHeaders map[string]string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:         "With an empty prefix, token and no additional headers", | ||||||
|  | 			prefix:       "", | ||||||
|  | 			token:        "", | ||||||
|  | 			extraHeaders: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:         "With a Bearer token type", | ||||||
|  | 			prefix:       tokenTypeBearer, | ||||||
|  | 			token:        "abcdef", | ||||||
|  | 			extraHeaders: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:         "With a Token token type", | ||||||
|  | 			prefix:       tokenTypeToken, | ||||||
|  | 			token:        "123456", | ||||||
|  | 			extraHeaders: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:   "With a Bearer token type and Accept application/json", | ||||||
|  | 			prefix: tokenTypeToken, | ||||||
|  | 			token:  "abc", | ||||||
|  | 			extraHeaders: map[string]string{ | ||||||
|  | 				acceptHeader: acceptApplicationJSON, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:   "With a Bearer token type and multiple headers", | ||||||
|  | 			prefix: tokenTypeToken, | ||||||
|  | 			token:  "123", | ||||||
|  | 			extraHeaders: map[string]string{ | ||||||
|  | 				acceptHeader: acceptApplicationJSON, | ||||||
|  | 				"foo":        "bar", | ||||||
|  | 				"key":        "value", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range testCases { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			g := NewWithT(t) | ||||||
|  | 
 | ||||||
|  | 			header := makeAuthorizationHeader(tc.prefix, tc.token, tc.extraHeaders) | ||||||
|  | 			g.Expect(header.Get("Authorization")).To(Equal(fmt.Sprintf("%s %s", tc.prefix, tc.token))) | ||||||
|  | 			for k, v := range tc.extraHeaders { | ||||||
|  | 				g.Expect(header.Get(k)).To(Equal(v)) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue