Move AllowedGroups to DefaultProvider for default Authorize usage
This commit is contained in:
		
							parent
							
								
									e7ac793044
								
							
						
					
					
						commit
						eb58ea2ed9
					
				|  | @ -105,7 +105,6 @@ type OAuthProxy struct { | ||||||
| 	trustedIPs              *ip.NetSet | 	trustedIPs              *ip.NetSet | ||||||
| 	Banner                  string | 	Banner                  string | ||||||
| 	Footer                  string | 	Footer                  string | ||||||
| 	AllowedGroups           []string |  | ||||||
| 
 | 
 | ||||||
| 	sessionChain alice.Chain | 	sessionChain alice.Chain | ||||||
| 	headersChain alice.Chain | 	headersChain alice.Chain | ||||||
|  | @ -219,7 +218,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		Banner:                  opts.Banner, | 		Banner:                  opts.Banner, | ||||||
| 		Footer:                  opts.Footer, | 		Footer:                  opts.Footer, | ||||||
| 		SignInMessage:           buildSignInMessage(opts), | 		SignInMessage:           buildSignInMessage(opts), | ||||||
| 		AllowedGroups:           opts.AllowedGroups, |  | ||||||
| 
 | 
 | ||||||
| 		basicAuthValidator:  basicAuthValidator, | 		basicAuthValidator:  basicAuthValidator, | ||||||
| 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | ||||||
|  | @ -992,13 +990,12 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	invalidEmail := session.Email != "" && !p.Validator(session.Email) | 	invalidEmail := session.Email != "" && !p.Validator(session.Email) | ||||||
| 	invalidGroups := session != nil && !p.validateGroups(session.Groups) |  | ||||||
| 	authorized, err := p.provider.Authorize(req.Context(), session) | 	authorized, err := p.provider.Authorize(req.Context(), session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Errorf("Error with authorization: %v", err) | 		logger.Errorf("Error with authorization: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if invalidEmail || invalidGroups || !authorized { | 	if invalidEmail || !authorized { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) | 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) | ||||||
| 		// Invalid session, clear it
 | 		// Invalid session, clear it
 | ||||||
| 		err := p.ClearSessionCookie(rw, req) | 		err := p.ClearSessionCookie(rw, req) | ||||||
|  | @ -1037,23 +1034,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | ||||||
| 	rw.Header().Set("Content-Type", applicationJSON) | 	rw.Header().Set("Content-Type", applicationJSON) | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) validateGroups(groups []string) bool { |  | ||||||
| 	if len(p.AllowedGroups) == 0 { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	allowedGroups := map[string]struct{}{} |  | ||||||
| 
 |  | ||||||
| 	for _, group := range p.AllowedGroups { |  | ||||||
| 		allowedGroups[group] = struct{}{} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, group := range groups { |  | ||||||
| 		if _, ok := allowedGroups[group]; ok { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   opts.providerValidateCookieResponse, | 		ValidToken:   opts.providerValidateCookieResponse, | ||||||
| 	} | 	} | ||||||
|  | 	pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups) | ||||||
| 
 | 
 | ||||||
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | ||||||
| 	// access_token validation.
 | 	// access_token validation.
 | ||||||
|  | @ -1132,10 +1134,7 @@ func TestUserInfoEndpointAccepted(t *testing.T) { | ||||||
| 	err = test.SaveSession(startSession) | 	err = test.SaveSession(startSession) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	return | ||||||
| 	assert.Equal(t, http.StatusOK, test.rw.Code) |  | ||||||
| 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body) |  | ||||||
| 	assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes)) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { | func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { | ||||||
|  | @ -1284,6 +1283,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   true, | 		ValidToken:   true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -1376,6 +1376,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   true, | 		ValidToken:   true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -1455,6 +1456,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   true, | 		ValidToken:   true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | ||||||
| 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | ||||||
| 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | ||||||
| 
 | 
 | ||||||
|  | 	p.SetAllowedGroups(o.AllowedGroups) | ||||||
|  | 
 | ||||||
| 	provider := providers.New(o.ProviderType, p) | 	provider := providers.New(o.ProviderType, p) | ||||||
| 	if provider == nil { | 	if provider == nil { | ||||||
| 		msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) | 		msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) | ||||||
|  |  | ||||||
|  | @ -26,6 +26,10 @@ type ProviderData struct { | ||||||
| 	ClientSecretFile string | 	ClientSecretFile string | ||||||
| 	Scope            string | 	Scope            string | ||||||
| 	Prompt           string | 	Prompt           string | ||||||
|  | 
 | ||||||
|  | 	// Universal Group authorization data structure
 | ||||||
|  | 	// any provider can set to consume
 | ||||||
|  | 	AllowedGroups map[string]struct{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Data returns the ProviderData
 | // Data returns the ProviderData
 | ||||||
|  | @ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) { | ||||||
| 	return string(fileClientSecret), nil | 	return string(fileClientSecret), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SetAllowedGroups organizes a group list into the AllowedGroups map
 | ||||||
|  | // to be consumed by Authorize implementations
 | ||||||
|  | func (p *ProviderData) SetAllowedGroups(groups []string) { | ||||||
|  | 	p.AllowedGroups = map[string]struct{}{} | ||||||
|  | 	for _, group := range groups { | ||||||
|  | 		p.AllowedGroups[group] = struct{}{} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type providerDefaults struct { | type providerDefaults struct { | ||||||
| 	name        string | 	name        string | ||||||
| 	loginURL    *url.URL | 	loginURL    *url.URL | ||||||
|  |  | ||||||
|  | @ -92,12 +92,6 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta | ||||||
| 	return "", ErrNotImplemented | 	return "", ErrNotImplemented | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateGroup validates that the provided email exists in the configured provider
 |  | ||||||
| // email group(s).
 |  | ||||||
| func (p *ProviderData) ValidateGroup(_ string) bool { |  | ||||||
| 	return true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | ||||||
| // such as User, Email, Groups with provider specific API calls.
 | // such as User, Email, Groups with provider specific API calls.
 | ||||||
| func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { | func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { | ||||||
|  | @ -107,7 +101,17 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session | ||||||
| // Authorize performs global authorization on an authenticated session.
 | // Authorize performs global authorization on an authenticated session.
 | ||||||
| // This is not used for fine-grained per route authorization rules.
 | // This is not used for fine-grained per route authorization rules.
 | ||||||
| func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) { | func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 	if len(p.AllowedGroups) == 0 { | ||||||
| 		return true, nil | 		return true, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, group := range s.Groups { | ||||||
|  | 		if _, ok := p.AllowedGroups[group]; ok { | ||||||
|  | 			return true, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) { | ||||||
| 	s := &sessions.SessionState{} | 	s := &sessions.SessionState{} | ||||||
| 	assert.NoError(t, p.EnrichSessionState(context.Background(), s)) | 	assert.NoError(t, p.EnrichSessionState(context.Background(), s)) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestProviderDataAuthorize(t *testing.T) { | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		name          string | ||||||
|  | 		allowedGroups []string | ||||||
|  | 		groups        []string | ||||||
|  | 		expectedAuthZ bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:          "NoAllowedGroups", | ||||||
|  | 			allowedGroups: []string{}, | ||||||
|  | 			groups:        []string{}, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "NoAllowedGroupsUserHasGroups", | ||||||
|  | 			allowedGroups: []string{}, | ||||||
|  | 			groups:        []string{"foo", "bar"}, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "UserInAllowedGroup", | ||||||
|  | 			allowedGroups: []string{"foo"}, | ||||||
|  | 			groups:        []string{"foo", "bar"}, | ||||||
|  | 			expectedAuthZ: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "UserNotInAllowedGroup", | ||||||
|  | 			allowedGroups: []string{"bar"}, | ||||||
|  | 			groups:        []string{"baz", "foo"}, | ||||||
|  | 			expectedAuthZ: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range testCases { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			g := NewWithT(t) | ||||||
|  | 
 | ||||||
|  | 			session := &sessions.SessionState{ | ||||||
|  | 				Groups: tc.groups, | ||||||
|  | 			} | ||||||
|  | 			p := &ProviderData{} | ||||||
|  | 			p.SetAllowedGroups(tc.allowedGroups) | ||||||
|  | 
 | ||||||
|  | 			authorized, err := p.Authorize(context.Background(), session) | ||||||
|  | 			g.Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			g.Expect(authorized).To(Equal(tc.expectedAuthZ)) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue