diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index b159df09..2ac4e9aa 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -212,7 +212,7 @@ Provider specific options can be found on their respective subpages. | flag: `--signature-key`
toml: `signature_key` | string | GAP-Signature request signature key (algorithm:secretkey) | | | flag: `--skip-auth-preflight`
toml: `skip_auth_preflight` | bool | will skip authentication for OPTIONS requests | false | | flag: `--skip-auth-regex`
toml: `skip_auth_regex` | string \| list | (DEPRECATED for `--skip-auth-route`) bypass authentication for requests paths that match (may be given multiple times) | | -| flag: `--skip-auth-route`
toml: `skip_auth_routes` | string \| list | bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex | | +| flag: `--skip-auth-route`
toml: `skip_auth_routes` | string \| list | bypass authentication for requests that match the method, path, or domain. Format: method=path_regex OR method!=path_regex OR domain=domain_regex. For all methods: path_regex OR !=path_regex. Domain matching uses the Host header (or X-Forwarded-Host if behind a reverse proxy). | | | flag: `--skip-jwt-bearer-tokens`
toml: `skip_jwt_bearer_tokens` | bool | will skip requests that have verified JWT bearer tokens (the token must have [`aud`](https://en.wikipedia.org/wiki/JSON_Web_Token#Standard_fields) that matches this client id or one of the extras from `extra-jwt-issuers`) | false | | flag: `--skip-provider-button`
toml: `skip_provider_button` | bool | will skip sign-in-page to directly reach the next step: oauth/start | false | | flag: `--ssl-insecure-skip-verify`
toml: `ssl_insecure_skip_verify` | bool | skip validation of certificates presented when using HTTPS providers | false | diff --git a/oauthproxy.go b/oauthproxy.go index 7526d641..297888f2 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -67,11 +67,12 @@ var ( staticFiles embed.FS ) -// allowedRoute manages method + path based allowlists +// allowedRoute manages method + path + domain based allowlists type allowedRoute struct { - method string - negate bool - pathRegex *regexp.Regexp + method string + negate bool + pathRegex *regexp.Regexp + domainRegex *regexp.Regexp } type apiRoute struct { @@ -464,7 +465,7 @@ func buildProviderName(p providers.Provider, override string) string { // buildRoutesAllowlist builds an []allowedRoute list from either the legacy // SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option -// (method=path support) +// (method=path and domain=domain_regex support) func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { routes := make([]allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes)) @@ -492,6 +493,23 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { method = "" path = parts[0] } else { + prefix := strings.ToLower(parts[0]) + // Check if this is a domain-based route + if prefix == "domain" { + // Domain-based route: domain=regex + domainPattern := parts[1] + compiledDomainRegex, err := regexp.Compile(domainPattern) + if err != nil { + return nil, err + } + logger.Printf("Skipping auth - Domain: %s", domainPattern) + routes = append(routes, allowedRoute{ + method: "", + domainRegex: compiledDomainRegex, + }) + continue + } + // Method-based route: method=path method = strings.ToUpper(parts[0]) path = parts[1] } @@ -580,6 +598,11 @@ func isAllowedMethod(req *http.Request, route allowedRoute) bool { } func isAllowedPath(req *http.Request, route allowedRoute) bool { + // If there's no path regex, consider the path as allowed + if route.pathRegex == nil { + return true + } + matches := route.pathRegex.MatchString(requestutil.GetRequestPath(req)) if route.negate { @@ -589,10 +612,20 @@ func isAllowedPath(req *http.Request, route allowedRoute) bool { return matches } -// IsAllowedRoute is used to check if the request method & path is allowed without auth +func isAllowedDomain(req *http.Request, route allowedRoute) bool { + // If there's no domain regex, consider the domain as allowed + if route.domainRegex == nil { + return true + } + + host := requestutil.GetRequestHost(req) + return route.domainRegex.MatchString(host) +} + +// IsAllowedRoute is used to check if the request method, path, and domain are allowed without auth func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { for _, route := range p.allowedRoutes { - if isAllowedMethod(req, route) && isAllowedPath(req, route) { + if isAllowedMethod(req, route) && isAllowedPath(req, route) && isAllowedDomain(req, route) { return true } } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 488b8cea..c087cff3 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -2249,9 +2249,10 @@ func TestTrustedIPs(t *testing.T) { func Test_buildRoutesAllowlist(t *testing.T) { type expectedAllowedRoute struct { - method string - negate bool - regexString string + method string + negate bool + regexString string + domainRegexString string } testCases := []struct { @@ -2375,6 +2376,49 @@ func Test_buildRoutesAllowlist(t *testing.T) { }, shouldError: false, }, + { + name: "Domain-based routes", + skipAuthRegex: []string{}, + skipAuthRoutes: []string{ + "domain=example\\.com", + "domain=.*\\.subdomain\\.com", + }, + expectedRoutes: []expectedAllowedRoute{ + { + method: "", + domainRegexString: "example\\.com", + }, + { + method: "", + domainRegexString: ".*\\.subdomain\\.com", + }, + }, + shouldError: false, + }, + { + name: "Mixed method and domain routes", + skipAuthRegex: []string{}, + skipAuthRoutes: []string{ + "GET=^/api/v1", + "domain=api\\.example\\.com", + "POST=^/webhook", + }, + expectedRoutes: []expectedAllowedRoute{ + { + method: "GET", + regexString: "^/api/v1", + }, + { + method: "", + domainRegexString: "api\\.example\\.com", + }, + { + method: "POST", + regexString: "^/webhook", + }, + }, + shouldError: false, + }, { name: "Invalid skipAuthRegex entry", skipAuthRegex: []string{ @@ -2417,7 +2461,22 @@ func Test_buildRoutesAllowlist(t *testing.T) { assert.Greater(t, len(tc.expectedRoutes), i) assert.Equal(t, route.method, tc.expectedRoutes[i].method) assert.Equal(t, route.negate, tc.expectedRoutes[i].negate) - assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString) + + // Check path regex if expected + if tc.expectedRoutes[i].regexString != "" { + assert.NotNil(t, route.pathRegex) + assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString) + } else { + assert.Nil(t, route.pathRegex) + } + + // Check domain regex if expected + if tc.expectedRoutes[i].domainRegexString != "" { + assert.NotNil(t, route.domainRegex) + assert.Equal(t, route.domainRegex.String(), tc.expectedRoutes[i].domainRegexString) + } else { + assert.Nil(t, route.domainRegex) + } } }) } @@ -2635,6 +2694,118 @@ func TestAllowedRequest(t *testing.T) { } } +func TestAllowedRequestWithDomain(t *testing.T) { + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte("Allowed Request")) + if err != nil { + t.Fatal(err) + } + })) + t.Cleanup(upstreamServer.Close) + + opts := baseTestOptions() + opts.UpstreamServers = options.UpstreamConfig{ + Upstreams: []options.Upstream{ + { + ID: upstreamServer.URL, + Path: "/", + URI: upstreamServer.URL, + }, + }, + } + opts.SkipAuthRoutes = []string{ + "domain=api\\.example\\.com", + "domain=.*\\.subdomain\\.com", + "GET=^/api/public", + } + err := validation.Validate(opts) + assert.NoError(t, err) + proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + method string + url string + host string + allowed bool + }{ + { + name: "Domain allowed - exact match", + method: "GET", + url: "/any/path", + host: "api.example.com", + allowed: true, + }, + { + name: "Domain allowed - subdomain match", + method: "POST", + url: "/any/path", + host: "test.subdomain.com", + allowed: true, + }, + { + name: "Domain allowed - another subdomain", + method: "GET", + url: "/different/path", + host: "app.subdomain.com", + allowed: true, + }, + { + name: "Domain denied - wrong domain", + method: "GET", + url: "/any/path", + host: "other.example.com", + allowed: false, + }, + { + name: "Domain denied - no host", + method: "GET", + url: "/any/path", + host: "", + allowed: false, + }, + { + name: "Path route allowed regardless of domain", + method: "GET", + url: "/api/public", + host: "different.com", + allowed: true, + }, + { + name: "Path route with allowed domain", + method: "GET", + url: "/api/public", + host: "api.example.com", + allowed: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(tc.method, tc.url, nil) + assert.NoError(t, err) + if tc.host != "" { + req.Host = tc.host + } + assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req)) + + rw := httptest.NewRecorder() + proxy.ServeHTTP(rw, req) + + if tc.allowed { + assert.Equal(t, 200, rw.Code) + assert.Equal(t, "Allowed Request", rw.Body.String()) + } else { + assert.Equal(t, 403, rw.Code) + } + }) + } +} + func TestAllowedRequestWithForwardedUriHeader(t *testing.T) { upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) diff --git a/pkg/validation/allowlist.go b/pkg/validation/allowlist.go index 12f67aa7..4b44dac8 100644 --- a/pkg/validation/allowlist.go +++ b/pkg/validation/allowlist.go @@ -27,7 +27,7 @@ func validateAllowlists(o *options.Options) []string { return msgs } -// validateAuthRoutes validates method=path routes passed with options.SkipAuthRoutes +// validateAuthRoutes validates method=path routes and domain=domain_regex passed with options.SkipAuthRoutes func validateAuthRoutes(o *options.Options) []string { msgs := []string{} for _, route := range o.SkipAuthRoutes { @@ -36,6 +36,7 @@ func validateAuthRoutes(o *options.Options) []string { if len(parts) == 1 { regex = parts[0] } else { + // For method or domain-based routes, validate the regex regex = parts[1] } _, err := regexp.Compile(regex)