Move AllowedGroups to DefaultProvider for default Authorize usage
This commit is contained in:
		
							parent
							
								
									e7ac793044
								
							
						
					
					
						commit
						eb58ea2ed9
					
				|  | @ -105,7 +105,6 @@ type OAuthProxy struct { | |||
| 	trustedIPs              *ip.NetSet | ||||
| 	Banner                  string | ||||
| 	Footer                  string | ||||
| 	AllowedGroups           []string | ||||
| 
 | ||||
| 	sessionChain alice.Chain | ||||
| 	headersChain alice.Chain | ||||
|  | @ -219,7 +218,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		Banner:                  opts.Banner, | ||||
| 		Footer:                  opts.Footer, | ||||
| 		SignInMessage:           buildSignInMessage(opts), | ||||
| 		AllowedGroups:           opts.AllowedGroups, | ||||
| 
 | ||||
| 		basicAuthValidator:  basicAuthValidator, | ||||
| 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | ||||
|  | @ -992,13 +990,12 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R | |||
| 	} | ||||
| 
 | ||||
| 	invalidEmail := session.Email != "" && !p.Validator(session.Email) | ||||
| 	invalidGroups := session != nil && !p.validateGroups(session.Groups) | ||||
| 	authorized, err := p.provider.Authorize(req.Context(), session) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error with authorization: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if invalidEmail || invalidGroups || !authorized { | ||||
| 	if invalidEmail || !authorized { | ||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) | ||||
| 		// Invalid session, clear it
 | ||||
| 		err := p.ClearSessionCookie(rw, req) | ||||
|  | @ -1037,23 +1034,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | |||
| 	rw.Header().Set("Content-Type", applicationJSON) | ||||
| 	rw.WriteHeader(code) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) validateGroups(groups []string) bool { | ||||
| 	if len(p.AllowedGroups) == 0 { | ||||
| 		return true | ||||
| 	} | ||||
| 
 | ||||
| 	allowedGroups := map[string]struct{}{} | ||||
| 
 | ||||
| 	for _, group := range p.AllowedGroups { | ||||
| 		allowedGroups[group] = struct{}{} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, group := range groups { | ||||
| 		if _, ok := allowedGroups[group]; ok { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false | ||||
| } | ||||
|  |  | |||
|  | @ -976,8 +976,10 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi | |||
| 		return nil, err | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   opts.providerValidateCookieResponse, | ||||
| 	} | ||||
| 	pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.AllowedGroups) | ||||
| 
 | ||||
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | ||||
| 	// access_token validation.
 | ||||
|  | @ -1132,10 +1134,7 @@ func TestUserInfoEndpointAccepted(t *testing.T) { | |||
| 	err = test.SaveSession(startSession) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | ||||
| 	assert.Equal(t, http.StatusOK, test.rw.Code) | ||||
| 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body) | ||||
| 	assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes)) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { | ||||
|  | @ -1284,6 +1283,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | |||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   true, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -1376,6 +1376,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { | |||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   true, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -1455,6 +1456,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { | |||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   true, | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -233,6 +233,8 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | |||
| 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | ||||
| 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | ||||
| 
 | ||||
| 	p.SetAllowedGroups(o.AllowedGroups) | ||||
| 
 | ||||
| 	provider := providers.New(o.ProviderType, p) | ||||
| 	if provider == nil { | ||||
| 		msgs = append(msgs, fmt.Sprintf("invalid setting: provider '%s' is not available", o.ProviderType)) | ||||
|  |  | |||
|  | @ -26,6 +26,10 @@ type ProviderData struct { | |||
| 	ClientSecretFile string | ||||
| 	Scope            string | ||||
| 	Prompt           string | ||||
| 
 | ||||
| 	// Universal Group authorization data structure
 | ||||
| 	// any provider can set to consume
 | ||||
| 	AllowedGroups map[string]struct{} | ||||
| } | ||||
| 
 | ||||
| // Data returns the ProviderData
 | ||||
|  | @ -45,6 +49,15 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) { | |||
| 	return string(fileClientSecret), nil | ||||
| } | ||||
| 
 | ||||
| // SetAllowedGroups organizes a group list into the AllowedGroups map
 | ||||
| // to be consumed by Authorize implementations
 | ||||
| func (p *ProviderData) SetAllowedGroups(groups []string) { | ||||
| 	p.AllowedGroups = map[string]struct{}{} | ||||
| 	for _, group := range groups { | ||||
| 		p.AllowedGroups[group] = struct{}{} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type providerDefaults struct { | ||||
| 	name        string | ||||
| 	loginURL    *url.URL | ||||
|  |  | |||
|  | @ -92,12 +92,6 @@ func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionSta | |||
| 	return "", ErrNotImplemented | ||||
| } | ||||
| 
 | ||||
| // ValidateGroup validates that the provided email exists in the configured provider
 | ||||
| // email group(s).
 | ||||
| func (p *ProviderData) ValidateGroup(_ string) bool { | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // EnrichSessionState is called after Redeem to allow providers to enrich session fields
 | ||||
| // such as User, Email, Groups with provider specific API calls.
 | ||||
| func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error { | ||||
|  | @ -107,7 +101,17 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session | |||
| // Authorize performs global authorization on an authenticated session.
 | ||||
| // This is not used for fine-grained per route authorization rules.
 | ||||
| func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 	if len(p.AllowedGroups) == 0 { | ||||
| 		return true, nil | ||||
| 	} | ||||
| 
 | ||||
| 	for _, group := range s.Groups { | ||||
| 		if _, ok := p.AllowedGroups[group]; ok { | ||||
| 			return true, nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return false, nil | ||||
| } | ||||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/gomega" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -53,3 +54,53 @@ func TestEnrichSessionState(t *testing.T) { | |||
| 	s := &sessions.SessionState{} | ||||
| 	assert.NoError(t, p.EnrichSessionState(context.Background(), s)) | ||||
| } | ||||
| 
 | ||||
| func TestProviderDataAuthorize(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name          string | ||||
| 		allowedGroups []string | ||||
| 		groups        []string | ||||
| 		expectedAuthZ bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:          "NoAllowedGroups", | ||||
| 			allowedGroups: []string{}, | ||||
| 			groups:        []string{}, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "NoAllowedGroupsUserHasGroups", | ||||
| 			allowedGroups: []string{}, | ||||
| 			groups:        []string{"foo", "bar"}, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "UserInAllowedGroup", | ||||
| 			allowedGroups: []string{"foo"}, | ||||
| 			groups:        []string{"foo", "bar"}, | ||||
| 			expectedAuthZ: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "UserNotInAllowedGroup", | ||||
| 			allowedGroups: []string{"bar"}, | ||||
| 			groups:        []string{"baz", "foo"}, | ||||
| 			expectedAuthZ: false, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			g := NewWithT(t) | ||||
| 
 | ||||
| 			session := &sessions.SessionState{ | ||||
| 				Groups: tc.groups, | ||||
| 			} | ||||
| 			p := &ProviderData{} | ||||
| 			p.SetAllowedGroups(tc.allowedGroups) | ||||
| 
 | ||||
| 			authorized, err := p.Authorize(context.Background(), session) | ||||
| 			g.Expect(err).ToNot(HaveOccurred()) | ||||
| 			g.Expect(authorized).To(Equal(tc.expectedAuthZ)) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue