Add tests for skip auth functionality
This commit is contained in:
		
							parent
							
								
									183cb124a4
								
							
						
					
					
						commit
						cfd3de807c
					
				|  | @ -286,7 +286,7 @@ func buildSignInMessage(opts *options.Options) string { | |||
| // SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option
 | ||||
| // (method=path support)
 | ||||
| func buildRoutesAllowlist(opts *options.Options) ([]*allowedRoute, error) { | ||||
| 	var routes []*allowedRoute | ||||
| 	routes := make([]*allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes)) | ||||
| 
 | ||||
| 	for _, path := range opts.SkipAuthRegex { | ||||
| 		compiledRegex, err := regexp.Compile(path) | ||||
|  |  | |||
|  | @ -1482,28 +1482,28 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestAuthSkippedForPreflightRequests(t *testing.T) { | ||||
| 	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(200) | ||||
| 		_, err := w.Write([]byte("response")) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	})) | ||||
| 	t.Cleanup(upstream.Close) | ||||
| 	t.Cleanup(upstreamServer.Close) | ||||
| 
 | ||||
| 	opts := baseTestOptions() | ||||
| 	opts.UpstreamServers = options.Upstreams{ | ||||
| 		{ | ||||
| 			ID:   upstream.URL, | ||||
| 			ID:   upstreamServer.URL, | ||||
| 			Path: "/", | ||||
| 			URI:  upstream.URL, | ||||
| 			URI:  upstreamServer.URL, | ||||
| 		}, | ||||
| 	} | ||||
| 	opts.SkipAuthPreflight = true | ||||
| 	err := validation.Validate(opts) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	upstreamURL, _ := url.Parse(upstream.URL) | ||||
| 	upstreamURL, _ := url.Parse(upstreamServer.URL) | ||||
| 	opts.SetProvider(NewTestProvider(upstreamURL, "")) | ||||
| 
 | ||||
| 	proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) | ||||
|  | @ -1561,17 +1561,17 @@ func NewSignatureTest() (*SignatureTest, error) { | |||
| 	opts.EmailDomains = []string{"acm.org"} | ||||
| 
 | ||||
| 	authenticator := &SignatureAuthenticator{} | ||||
| 	upstream := httptest.NewServer( | ||||
| 	upstreamServer := httptest.NewServer( | ||||
| 		http.HandlerFunc(authenticator.Authenticate)) | ||||
| 	upstreamURL, err := url.Parse(upstream.URL) | ||||
| 	upstreamURL, err := url.Parse(upstreamServer.URL) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	opts.UpstreamServers = options.Upstreams{ | ||||
| 		{ | ||||
| 			ID:   upstream.URL, | ||||
| 			ID:   upstreamServer.URL, | ||||
| 			Path: "/", | ||||
| 			URI:  upstream.URL, | ||||
| 			URI:  upstreamServer.URL, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -1590,7 +1590,7 @@ func NewSignatureTest() (*SignatureTest, error) { | |||
| 
 | ||||
| 	return &SignatureTest{ | ||||
| 		opts, | ||||
| 		upstream, | ||||
| 		upstreamServer, | ||||
| 		upstreamURL.Host, | ||||
| 		provider, | ||||
| 		make(http.Header), | ||||
|  | @ -1974,20 +1974,20 @@ func Test_prepareNoCache(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func Test_noCacheHeaders(t *testing.T) { | ||||
| 	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		_, err := w.Write([]byte("upstream")) | ||||
| 		if err != nil { | ||||
| 			t.Error(err) | ||||
| 		} | ||||
| 	})) | ||||
| 	t.Cleanup(upstream.Close) | ||||
| 	t.Cleanup(upstreamServer.Close) | ||||
| 
 | ||||
| 	opts := baseTestOptions() | ||||
| 	opts.UpstreamServers = options.Upstreams{ | ||||
| 		{ | ||||
| 			ID:   upstream.URL, | ||||
| 			ID:   upstreamServer.URL, | ||||
| 			Path: "/", | ||||
| 			URI:  upstream.URL, | ||||
| 			URI:  upstreamServer.URL, | ||||
| 		}, | ||||
| 	} | ||||
| 	opts.SkipAuthRegex = []string{".*"} | ||||
|  | @ -2224,7 +2224,8 @@ func TestTrustedIPs(t *testing.T) { | |||
| 			opts.TrustedIPs = tt.trustedIPs | ||||
| 			opts.ReverseProxy = tt.reverseProxy | ||||
| 			opts.RealClientIPHeader = tt.realClientIPHeader | ||||
| 			validation.Validate(opts) | ||||
| 			err := validation.Validate(opts) | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) | ||||
| 			assert.NoError(t, err) | ||||
|  | @ -2240,6 +2241,237 @@ func TestTrustedIPs(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Test_buildRoutesAllowlist(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name            string | ||||
| 		skipAuthRegex   []string | ||||
| 		skipAuthRoutes  []string | ||||
| 		expectedMethods []string | ||||
| 		expectedRegexes []string | ||||
| 		shouldError     bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:            "No skip auth configured", | ||||
| 			skipAuthRegex:   []string{}, | ||||
| 			skipAuthRoutes:  []string{}, | ||||
| 			expectedMethods: []string{}, | ||||
| 			expectedRegexes: []string{}, | ||||
| 			shouldError:     false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Only skipAuthRegex configured", | ||||
| 			skipAuthRegex: []string{ | ||||
| 				"^/foo/bar", | ||||
| 				"^/baz/[0-9]+/thing", | ||||
| 			}, | ||||
| 			skipAuthRoutes: []string{}, | ||||
| 			expectedMethods: []string{ | ||||
| 				"", | ||||
| 				"", | ||||
| 			}, | ||||
| 			expectedRegexes: []string{ | ||||
| 				"^/foo/bar", | ||||
| 				"^/baz/[0-9]+/thing", | ||||
| 			}, | ||||
| 			shouldError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "Only skipAuthRoutes configured", | ||||
| 			skipAuthRegex: []string{}, | ||||
| 			skipAuthRoutes: []string{ | ||||
| 				"GET=^/foo/bar", | ||||
| 				"POST=^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 				"WEIRD=^/methods/are/allowed", | ||||
| 				"PATCH=/second/equals?are=handled&just=fine", | ||||
| 			}, | ||||
| 			expectedMethods: []string{ | ||||
| 				"GET", | ||||
| 				"POST", | ||||
| 				"", | ||||
| 				"WEIRD", | ||||
| 				"PATCH", | ||||
| 			}, | ||||
| 			expectedRegexes: []string{ | ||||
| 				"^/foo/bar", | ||||
| 				"^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 				"^/methods/are/allowed", | ||||
| 				"/second/equals?are=handled&just=fine", | ||||
| 			}, | ||||
| 			shouldError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Both skipAuthRegexes and skipAuthRoutes configured", | ||||
| 			skipAuthRegex: []string{ | ||||
| 				"^/foo/bar/regex", | ||||
| 				"^/baz/[0-9]+/thing/regex", | ||||
| 			}, | ||||
| 			skipAuthRoutes: []string{ | ||||
| 				"GET=^/foo/bar", | ||||
| 				"POST=^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 			}, | ||||
| 			expectedMethods: []string{ | ||||
| 				"", | ||||
| 				"", | ||||
| 				"GET", | ||||
| 				"POST", | ||||
| 				"", | ||||
| 			}, | ||||
| 			expectedRegexes: []string{ | ||||
| 				"^/foo/bar/regex", | ||||
| 				"^/baz/[0-9]+/thing/regex", | ||||
| 				"^/foo/bar", | ||||
| 				"^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 			}, | ||||
| 			shouldError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Invalid skipAuthRegex entry", | ||||
| 			skipAuthRegex: []string{ | ||||
| 				"^/foo/bar", | ||||
| 				"^/baz/[0-9]+/thing", | ||||
| 				"(bad[regex", | ||||
| 			}, | ||||
| 			skipAuthRoutes:  []string{}, | ||||
| 			expectedMethods: []string{}, | ||||
| 			expectedRegexes: []string{}, | ||||
| 			shouldError:     true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "Invalid skipAuthRoutes entry", | ||||
| 			skipAuthRegex: []string{}, | ||||
| 			skipAuthRoutes: []string{ | ||||
| 				"GET=^/foo/bar", | ||||
| 				"POST=^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 				"PUT=(bad[regex", | ||||
| 			}, | ||||
| 			expectedMethods: []string{}, | ||||
| 			expectedRegexes: []string{}, | ||||
| 			shouldError:     true, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			opts := &options.Options{ | ||||
| 				SkipAuthRegex:  tc.skipAuthRegex, | ||||
| 				SkipAuthRoutes: tc.skipAuthRoutes, | ||||
| 			} | ||||
| 			routes, err := buildRoutesAllowlist(opts) | ||||
| 			if tc.shouldError { | ||||
| 				assert.Error(t, err) | ||||
| 				return | ||||
| 			} else { | ||||
| 				assert.NoError(t, err) | ||||
| 			} | ||||
| 			for i, route := range routes { | ||||
| 				assert.Greater(t, len(tc.expectedMethods), i) | ||||
| 				assert.Equal(t, route.method, tc.expectedMethods[i]) | ||||
| 				assert.Greater(t, len(tc.expectedRegexes), i) | ||||
| 				assert.Equal(t, route.pathRegex.String(), tc.expectedRegexes[i]) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestAllowedRequest(t *testing.T) { | ||||
| 	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(200) | ||||
| 		_, err := w.Write([]byte("Allowed Request")) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	})) | ||||
| 	t.Cleanup(upstreamServer.Close) | ||||
| 
 | ||||
| 	opts := baseTestOptions() | ||||
| 	opts.UpstreamServers = options.Upstreams{ | ||||
| 		{ | ||||
| 			ID:   upstreamServer.URL, | ||||
| 			Path: "/", | ||||
| 			URI:  upstreamServer.URL, | ||||
| 		}, | ||||
| 	} | ||||
| 	opts.SkipAuthRegex = []string{ | ||||
| 		"^/skip/auth/regex$", | ||||
| 	} | ||||
| 	opts.SkipAuthRoutes = []string{ | ||||
| 		"GET=^/skip/auth/routes/get", | ||||
| 	} | ||||
| 	err := validation.Validate(opts) | ||||
| 	assert.NoError(t, err) | ||||
| 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		name    string | ||||
| 		method  string | ||||
| 		url     string | ||||
| 		allowed bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:    "Regex GET allowed", | ||||
| 			method:  "GET", | ||||
| 			url:     "/skip/auth/regex", | ||||
| 			allowed: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "Regex POST allowed ", | ||||
| 			method:  "POST", | ||||
| 			url:     "/skip/auth/regex", | ||||
| 			allowed: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "Regex denied", | ||||
| 			method:  "GET", | ||||
| 			url:     "/wrong/denied", | ||||
| 			allowed: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "Route allowed", | ||||
| 			method:  "GET", | ||||
| 			url:     "/skip/auth/routes/get", | ||||
| 			allowed: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "Route denied with wrong method", | ||||
| 			method:  "PATCH", | ||||
| 			url:     "/skip/auth/routes/get", | ||||
| 			allowed: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "Route denied with wrong path", | ||||
| 			method:  "GET", | ||||
| 			url:     "/skip/auth/routes/wrong/path", | ||||
| 			allowed: false, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			req, err := http.NewRequest(tc.method, tc.url, nil) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req)) | ||||
| 
 | ||||
| 			rw := httptest.NewRecorder() | ||||
| 			proxy.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			if tc.allowed { | ||||
| 				assert.Equal(t, 200, rw.Code) | ||||
| 				assert.Equal(t, "Allowed Request", rw.Body.String()) | ||||
| 			} else { | ||||
| 				assert.Equal(t, 403, rw.Code) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestProxyAllowedGroups(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name               string | ||||
|  | @ -2265,18 +2497,18 @@ func TestProxyAllowedGroups(t *testing.T) { | |||
| 				CreatedAt:   &created, | ||||
| 			} | ||||
| 
 | ||||
| 			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 				w.WriteHeader(200) | ||||
| 			})) | ||||
| 			t.Cleanup(upstream.Close) | ||||
| 			t.Cleanup(upstreamServer.Close) | ||||
| 
 | ||||
| 			test, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { | ||||
| 				opts.AllowedGroups = tt.allowedGroups | ||||
| 				opts.UpstreamServers = options.Upstreams{ | ||||
| 					{ | ||||
| 						ID:   upstream.URL, | ||||
| 						ID:   upstreamServer.URL, | ||||
| 						Path: "/", | ||||
| 						URI:  upstream.URL, | ||||
| 						URI:  upstreamServer.URL, | ||||
| 					}, | ||||
| 				} | ||||
| 			}) | ||||
|  | @ -2287,7 +2519,8 @@ func TestProxyAllowedGroups(t *testing.T) { | |||
| 			test.req, _ = http.NewRequest("GET", "/", nil) | ||||
| 
 | ||||
| 			test.req.Header.Add("accept", applicationJSON) | ||||
| 			test.SaveSession(session) | ||||
| 			err = test.SaveSession(session) | ||||
| 			assert.NoError(t, err) | ||||
| 			test.proxy.ServeHTTP(test.rw, test.req) | ||||
| 
 | ||||
| 			if tt.expectUnauthorized { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue