Create generic Authorization Header constructor
This commit is contained in:
		
							parent
							
								
									9a338d8a34
								
							
						
					
					
						commit
						d05e08cba3
					
				|  | @ -11,6 +11,7 @@ | |||
| 
 | ||||
| ## Changes since v6.0.0 | ||||
| 
 | ||||
| - [#562](https://github.com/oauth2-proxy/oauth2-proxy/pull/562) Create generic Authorization Header constructor (@JoelSpeed) | ||||
| - [#715](https://github.com/oauth2-proxy/oauth2-proxy/pull/715) Ensure session times are not nil before printing them (@JoelSpeed) | ||||
| - [#714](https://github.com/oauth2-proxy/oauth2-proxy/pull/714) Support passwords with Redis session stores (@NickMeves) | ||||
| - [#719](https://github.com/oauth2-proxy/oauth2-proxy/pull/719) Add Gosec fixes to areas that are intermittently flagged on PRs (@NickMeves) | ||||
|  |  | |||
|  | @ -4,7 +4,6 @@ import ( | |||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
|  | @ -154,10 +153,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 	return | ||||
| } | ||||
| 
 | ||||
| func getAzureHeader(accessToken string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||
| 	return header | ||||
| func makeAzureHeader(accessToken string) http.Header { | ||||
| 	return makeAuthorizationHeader(tokenTypeBearer, accessToken, nil) | ||||
| } | ||||
| 
 | ||||
| func getEmailFromJSON(json *simplejson.Json) (string, error) { | ||||
|  | @ -188,7 +185,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session | |||
| 
 | ||||
| 	json, err := requests.New(p.ProfileURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getAzureHeader(s.AccessToken)). | ||||
| 		WithHeaders(makeAzureHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -3,8 +3,6 @@ package providers | |||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
|  | @ -62,13 +60,6 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { | |||
| 	return &DigitalOceanProvider{ProviderData: p} | ||||
| } | ||||
| 
 | ||||
| func getDigitalOceanHeader(accessToken string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	header.Set("Content-Type", "application/json") | ||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||
| 	return header | ||||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||
| 	if s.AccessToken == "" { | ||||
|  | @ -77,7 +68,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. | |||
| 
 | ||||
| 	json, err := requests.New(p.ProfileURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getDigitalOceanHeader(s.AccessToken)). | ||||
| 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
|  | @ -93,5 +84,5 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. | |||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *DigitalOceanProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||
| 	return validateToken(ctx, p, s.AccessToken, getDigitalOceanHeader(s.AccessToken)) | ||||
| 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | ||||
| } | ||||
|  |  | |||
|  | @ -3,8 +3,6 @@ package providers | |||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
|  | @ -63,14 +61,6 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { | |||
| 	return &FacebookProvider{ProviderData: p} | ||||
| } | ||||
| 
 | ||||
| func getFacebookHeader(accessToken string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	header.Set("Accept", "application/json") | ||||
| 	header.Set("x-li-format", "json") | ||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||
| 	return header | ||||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||
| 	if s.AccessToken == "" { | ||||
|  | @ -85,7 +75,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | |||
| 	requestURL := p.ProfileURL.String() + "?fields=name,email" | ||||
| 	err := requests.New(requestURL). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getFacebookHeader(s.AccessToken)). | ||||
| 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&r) | ||||
| 	if err != nil { | ||||
|  | @ -100,5 +90,5 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | |||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *FacebookProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||
| 	return validateToken(ctx, p, s.AccessToken, getFacebookHeader(s.AccessToken)) | ||||
| 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | ||||
| } | ||||
|  |  | |||
|  | @ -74,11 +74,12 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider { | |||
| 	return &GitHubProvider{ProviderData: p} | ||||
| } | ||||
| 
 | ||||
| func getGitHubHeader(accessToken string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	header.Set("Accept", "application/vnd.github.v3+json") | ||||
| 	header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) | ||||
| 	return header | ||||
| func makeGitHubHeader(accessToken string) http.Header { | ||||
| 	// extra headers required by the GitHub API when making authenticated requests
 | ||||
| 	extraHeaders := map[string]string{ | ||||
| 		acceptHeader: "application/vnd.github.v3+json", | ||||
| 	} | ||||
| 	return makeAuthorizationHeader(tokenTypeToken, accessToken, extraHeaders) | ||||
| } | ||||
| 
 | ||||
| // SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope
 | ||||
|  | @ -129,7 +130,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, | |||
| 		var op orgsPage | ||||
| 		err := requests.New(endpoint.String()). | ||||
| 			WithContext(ctx). | ||||
| 			WithHeaders(getGitHubHeader(accessToken)). | ||||
| 			WithHeaders(makeGitHubHeader(accessToken)). | ||||
| 			Do(). | ||||
| 			UnmarshalInto(&op) | ||||
| 		if err != nil { | ||||
|  | @ -196,7 +197,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | |||
| 		// nolint:bodyclose
 | ||||
| 		result := requests.New(endpoint.String()). | ||||
| 			WithContext(ctx). | ||||
| 			WithHeaders(getGitHubHeader(accessToken)). | ||||
| 			WithHeaders(makeGitHubHeader(accessToken)). | ||||
| 			Do() | ||||
| 		if result.Error() != nil { | ||||
| 			return false, result.Error() | ||||
|  | @ -296,7 +297,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, | |||
| 	var repo repository | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(accessToken)). | ||||
| 		WithHeaders(makeGitHubHeader(accessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&repo) | ||||
| 	if err != nil { | ||||
|  | @ -324,7 +325,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, | |||
| 
 | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(accessToken)). | ||||
| 		WithHeaders(makeGitHubHeader(accessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&user) | ||||
| 	if err != nil { | ||||
|  | @ -347,7 +348,7 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok | |||
| 	} | ||||
| 	result := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(accessToken)). | ||||
| 		WithHeaders(makeGitHubHeader(accessToken)). | ||||
| 		Do() | ||||
| 	if result.Error() != nil { | ||||
| 		return false, result.Error() | ||||
|  | @ -411,7 +412,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio | |||
| 	} | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||
| 		WithHeaders(makeGitHubHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&emails) | ||||
| 	if err != nil { | ||||
|  | @ -446,7 +447,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta | |||
| 
 | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||
| 		WithHeaders(makeGitHubHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&user) | ||||
| 	if err != nil { | ||||
|  | @ -465,7 +466,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta | |||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||
| 	return validateToken(ctx, p, s.AccessToken, getGitHubHeader(s.AccessToken)) | ||||
| 	return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken)) | ||||
| } | ||||
| 
 | ||||
| // isVerifiedUser
 | ||||
|  |  | |||
|  | @ -3,7 +3,6 @@ package providers | |||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 
 | ||||
|  | @ -62,12 +61,13 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { | |||
| 	return &LinkedInProvider{ProviderData: p} | ||||
| } | ||||
| 
 | ||||
| func getLinkedInHeader(accessToken string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	header.Set("Accept", "application/json") | ||||
| 	header.Set("x-li-format", "json") | ||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||
| 	return header | ||||
| func makeLinkedInHeader(accessToken string) http.Header { | ||||
| 	// extra headers required by the LinkedIn API when making authenticated requests
 | ||||
| 	extraHeaders := map[string]string{ | ||||
| 		acceptHeader:  acceptApplicationJSON, | ||||
| 		"x-li-format": "json", | ||||
| 	} | ||||
| 	return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) | ||||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
|  | @ -79,7 +79,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | |||
| 	requestURL := p.ProfileURL.String() + "?format=json" | ||||
| 	json, err := requests.New(requestURL). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getLinkedInHeader(s.AccessToken)). | ||||
| 		WithHeaders(makeLinkedInHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
|  | @ -95,5 +95,5 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | |||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *LinkedInProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { | ||||
| 	return validateToken(ctx, p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | ||||
| 	return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) | ||||
| } | ||||
|  |  | |||
|  | @ -3,7 +3,6 @@ package providers | |||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||
|  | @ -24,17 +23,11 @@ func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { | |||
| 	return &NextcloudProvider{ProviderData: p} | ||||
| } | ||||
| 
 | ||||
| func getNextcloudHeader(accessToken string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||
| 	return header | ||||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||
| 	json, err := requests.New(p.ValidateURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getNextcloudHeader(s.AccessToken)). | ||||
| 		WithHeaders(makeOIDCHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -3,7 +3,6 @@ package providers | |||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
|  | @ -221,13 +220,6 @@ func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.Ses | |||
| 	return err == nil | ||||
| } | ||||
| 
 | ||||
| func getOIDCHeader(accessToken string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	header.Set("Accept", "application/json") | ||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||
| 	return header | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { | ||||
| 	claims := &OIDCClaims{} | ||||
| 	// Extract default claims.
 | ||||
|  | @ -263,7 +255,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | |||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | ||||
| 		respJSON, err := requests.New(profileURL). | ||||
| 			WithContext(ctx). | ||||
| 			WithHeaders(getOIDCHeader(token.AccessToken)). | ||||
| 			WithHeaders(makeOIDCHeader(token.AccessToken)). | ||||
| 			Do(). | ||||
| 			UnmarshalJSON() | ||||
| 		if err != nil { | ||||
|  |  | |||
|  | @ -0,0 +1,31 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	tokenTypeBearer = "Bearer" | ||||
| 	tokenTypeToken  = "token" | ||||
| 
 | ||||
| 	acceptHeader          = "Accept" | ||||
| 	acceptApplicationJSON = "application/json" | ||||
| ) | ||||
| 
 | ||||
| func makeAuthorizationHeader(prefix, token string, extraHeaders map[string]string) http.Header { | ||||
| 	header := make(http.Header) | ||||
| 	for key, value := range extraHeaders { | ||||
| 		header.Add(key, value) | ||||
| 	} | ||||
| 	header.Set("Authorization", fmt.Sprintf("%s %s", prefix, token)) | ||||
| 	return header | ||||
| } | ||||
| 
 | ||||
| func makeOIDCHeader(accessToken string) http.Header { | ||||
| 	// extra headers required by the IDP when making authenticated requests
 | ||||
| 	extraHeaders := map[string]string{ | ||||
| 		acceptHeader: acceptApplicationJSON, | ||||
| 	} | ||||
| 	return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) | ||||
| } | ||||
|  | @ -0,0 +1,66 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| func TestMakeAuhtorizationHeader(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name         string | ||||
| 		prefix       string | ||||
| 		token        string | ||||
| 		extraHeaders map[string]string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:         "With an empty prefix, token and no additional headers", | ||||
| 			prefix:       "", | ||||
| 			token:        "", | ||||
| 			extraHeaders: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:         "With a Bearer token type", | ||||
| 			prefix:       tokenTypeBearer, | ||||
| 			token:        "abcdef", | ||||
| 			extraHeaders: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:         "With a Token token type", | ||||
| 			prefix:       tokenTypeToken, | ||||
| 			token:        "123456", | ||||
| 			extraHeaders: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:   "With a Bearer token type and Accept application/json", | ||||
| 			prefix: tokenTypeToken, | ||||
| 			token:  "abc", | ||||
| 			extraHeaders: map[string]string{ | ||||
| 				acceptHeader: acceptApplicationJSON, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:   "With a Bearer token type and multiple headers", | ||||
| 			prefix: tokenTypeToken, | ||||
| 			token:  "123", | ||||
| 			extraHeaders: map[string]string{ | ||||
| 				acceptHeader: acceptApplicationJSON, | ||||
| 				"foo":        "bar", | ||||
| 				"key":        "value", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			g := NewWithT(t) | ||||
| 
 | ||||
| 			header := makeAuthorizationHeader(tc.prefix, tc.token, tc.extraHeaders) | ||||
| 			g.Expect(header.Get("Authorization")).To(Equal(fmt.Sprintf("%s %s", tc.prefix, tc.token))) | ||||
| 			for k, v := range tc.extraHeaders { | ||||
| 				g.Expect(header.Get(k)).To(Equal(v)) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
		Reference in New Issue