Update unit tests for ValidateGroup
This commit is contained in:
		
							parent
							
								
									bd651df3c2
								
							
						
					
					
						commit
						3881955605
					
				|  | @ -231,6 +231,7 @@ type TestProvider struct { | ||||||
| 	*providers.ProviderData | 	*providers.ProviderData | ||||||
| 	EmailAddress   string | 	EmailAddress   string | ||||||
| 	ValidToken     bool | 	ValidToken     bool | ||||||
|  | 	GroupValidator func(string) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { | func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { | ||||||
|  | @ -255,6 +256,9 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { | ||||||
| 			Scope: "profile.email", | 			Scope: "profile.email", | ||||||
| 		}, | 		}, | ||||||
| 		EmailAddress: emailAddress, | 		EmailAddress: emailAddress, | ||||||
|  | 		GroupValidator: func(s string) bool { | ||||||
|  | 			return true | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -266,6 +270,13 @@ func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) boo | ||||||
| 	return tp.ValidToken | 	return tp.ValidToken | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (tp *TestProvider) ValidateGroup(email string) bool { | ||||||
|  | 	if tp.GroupValidator != nil { | ||||||
|  | 		return tp.GroupValidator(email) | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestBasicAuthPassword(t *testing.T) { | func TestBasicAuthPassword(t *testing.T) { | ||||||
| 	providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		logger.Printf("%#v", r) | 		logger.Printf("%#v", r) | ||||||
|  | @ -791,6 +802,25 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||||
| 	assert.Equal(t, "unauthorized request\n", string(bodyBytes)) | 	assert.Equal(t, "unauthorized request\n", string(bodyBytes)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestAuthOnlyEndpointUnauthorizedOnProviderGroupValidationFailure(t *testing.T) { | ||||||
|  | 	test := NewAuthOnlyEndpointTest() | ||||||
|  | 	startSession := &sessions.SessionState{ | ||||||
|  | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||||
|  | 	test.SaveSession(startSession) | ||||||
|  | 	provider := &TestProvider{ | ||||||
|  | 		ValidToken: true, | ||||||
|  | 		GroupValidator: func(s string) bool { | ||||||
|  | 			return false | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	test.proxy.provider = provider | ||||||
|  | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
|  | 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | ||||||
|  | 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body) | ||||||
|  | 	assert.Equal(t, "unauthorized request\n", string(bodyBytes)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | ||||||
| 	var pcTest ProcessCookieTest | 	var pcTest ProcessCookieTest | ||||||
| 
 | 
 | ||||||
|  | @ -1168,69 +1198,80 @@ func TestGetJwtSession(t *testing.T) { | ||||||
| 	keyset := NoOpKeySet{} | 	keyset := NoOpKeySet{} | ||||||
| 	verifier := oidc.NewVerifier("https://issuer.example.com", keyset, | 	verifier := oidc.NewVerifier("https://issuer.example.com", keyset, | ||||||
| 		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) | 		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) | ||||||
| 	p := OAuthProxy{} |  | ||||||
| 	p.jwtBearerVerifiers = append(p.jwtBearerVerifiers, verifier) |  | ||||||
| 
 | 
 | ||||||
| 	req, _ := http.NewRequest("GET", "/", strings.NewReader("")) | 	test := NewAuthOnlyEndpointTest(func(opts *Options) { | ||||||
|  | 		opts.PassAuthorization = true | ||||||
|  | 		opts.SetAuthorization = true | ||||||
|  | 		opts.SetXAuthRequest = true | ||||||
|  | 		opts.SkipJwtBearerTokens = true | ||||||
|  | 		opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier) | ||||||
|  | 	}) | ||||||
|  | 	tp, _ := test.proxy.provider.(*TestProvider) | ||||||
|  | 	tp.GroupValidator = func(s string) bool { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	authHeader := fmt.Sprintf("Bearer %s", goodJwt) | 	authHeader := fmt.Sprintf("Bearer %s", goodJwt) | ||||||
| 	req.Header = map[string][]string{ | 	test.req.Header = map[string][]string{ | ||||||
| 		"Authorization": {authHeader}, | 		"Authorization": {authHeader}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Bearer
 | 	// Bearer
 | ||||||
| 	session, _ := p.GetJwtSession(req) | 	session, _ := test.proxy.GetJwtSession(test.req) | ||||||
| 	assert.Equal(t, session.User, "john@example.com") | 	assert.Equal(t, session.User, "john@example.com") | ||||||
| 	assert.Equal(t, session.Email, "john@example.com") | 	assert.Equal(t, session.Email, "john@example.com") | ||||||
| 	assert.Equal(t, session.ExpiresOn, time.Unix(1912151821, 0)) | 	assert.Equal(t, session.ExpiresOn, time.Unix(1912151821, 0)) | ||||||
| 	assert.Equal(t, session.IDToken, goodJwt) | 	assert.Equal(t, session.IDToken, goodJwt) | ||||||
| 
 | 
 | ||||||
| 	jwtProviderServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
| 		logger.Printf("%#v", r) | 	if test.rw.Code >= 400 { | ||||||
| 		var payload string | 		t.Fatalf("expected 3xx got %d", test.rw.Code) | ||||||
| 		payload = r.Header.Get("Authorization") |  | ||||||
| 		if payload == "" { |  | ||||||
| 			payload = "No Authorization header found." |  | ||||||
| 		} |  | ||||||
| 		w.WriteHeader(200) |  | ||||||
| 		w.Write([]byte(payload)) |  | ||||||
| 	})) |  | ||||||
| 
 |  | ||||||
| 	opts := NewOptions() |  | ||||||
| 	opts.Upstreams = append(opts.Upstreams, jwtProviderServer.URL) |  | ||||||
| 	opts.PassAuthorization = true |  | ||||||
| 	opts.SetAuthorization = true |  | ||||||
| 	opts.SetXAuthRequest = true |  | ||||||
| 	opts.CookieSecret = "0123456789abcdef0123" |  | ||||||
| 	opts.SkipJwtBearerTokens = true |  | ||||||
| 	opts.Validate() |  | ||||||
| 
 |  | ||||||
| 	// We can't actually use opts.Validate() because it will attempt to find a jwks URI
 |  | ||||||
| 	opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier) |  | ||||||
| 
 |  | ||||||
| 	providerURL, _ := url.Parse(jwtProviderServer.URL) |  | ||||||
| 	const emailAddress = "john@example.com" |  | ||||||
| 
 |  | ||||||
| 	opts.provider = NewTestProvider(providerURL, emailAddress) |  | ||||||
| 	jwtTestProxy := NewOAuthProxy(opts, func(email string) bool { |  | ||||||
| 		return email == emailAddress |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	rw := httptest.NewRecorder() |  | ||||||
| 	jwtTestProxy.ServeHTTP(rw, req) |  | ||||||
| 	if rw.Code >= 400 { |  | ||||||
| 		t.Fatalf("expected 3xx got %d", rw.Code) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check PassAuthorization, should overwrite Basic header
 | 	// Check PassAuthorization, should overwrite Basic header
 | ||||||
| 	assert.Equal(t, req.Header.Get("Authorization"), authHeader) | 	assert.Equal(t, test.req.Header.Get("Authorization"), authHeader) | ||||||
| 	assert.Equal(t, req.Header.Get("X-Forwarded-User"), "john@example.com") | 	assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "john@example.com") | ||||||
| 	assert.Equal(t, req.Header.Get("X-Forwarded-Email"), "john@example.com") | 	assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com") | ||||||
| 
 | 
 | ||||||
| 	// SetAuthorization and SetXAuthRequest
 | 	// SetAuthorization and SetXAuthRequest
 | ||||||
| 	assert.Equal(t, rw.Header().Get("Authorization"), authHeader) | 	assert.Equal(t, test.rw.Header().Get("Authorization"), authHeader) | ||||||
| 	assert.Equal(t, rw.Header().Get("X-Auth-Request-User"), "john@example.com") | 	assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "john@example.com") | ||||||
| 	assert.Equal(t, rw.Header().Get("X-Auth-Request-Email"), "john@example.com") | 	assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com") | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
|  | func TestJwtUnauthorizedOnGroupValidationFailure(t *testing.T) { | ||||||
|  | 	goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + | ||||||
|  | 		"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" + | ||||||
|  | 		"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" + | ||||||
|  | 		"E1LCJleHAiOjE5MTIxNTE4MjF9." + | ||||||
|  | 		"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" + | ||||||
|  | 		"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8" | ||||||
|  | 
 | ||||||
|  | 	keyset := NoOpKeySet{} | ||||||
|  | 	verifier := oidc.NewVerifier("https://issuer.example.com", keyset, | ||||||
|  | 		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) | ||||||
|  | 
 | ||||||
|  | 	test := NewAuthOnlyEndpointTest(func(opts *Options) { | ||||||
|  | 		opts.PassAuthorization = true | ||||||
|  | 		opts.SetAuthorization = true | ||||||
|  | 		opts.SetXAuthRequest = true | ||||||
|  | 		opts.SkipJwtBearerTokens = true | ||||||
|  | 		opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier) | ||||||
|  | 	}) | ||||||
|  | 	tp, _ := test.proxy.provider.(*TestProvider) | ||||||
|  | 	// Verify ValidateGroup fails JWT authorization
 | ||||||
|  | 	tp.GroupValidator = func(s string) bool { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	authHeader := fmt.Sprintf("Bearer %s", goodJwt) | ||||||
|  | 	test.req.Header = map[string][]string{ | ||||||
|  | 		"Authorization": {authHeader}, | ||||||
|  | 	} | ||||||
|  | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
|  | 	if test.rw.Code != http.StatusUnauthorized { | ||||||
|  | 		t.Fatalf("expected 401 got %d", test.rw.Code) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestFindJwtBearerToken(t *testing.T) { | func TestFindJwtBearerToken(t *testing.T) { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue