WIP
This commit is contained in:
		
							parent
							
								
									374a676c9d
								
							
						
					
					
						commit
						0dbda5dfac
					
				
							
								
								
									
										128
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										128
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -11,7 +11,6 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" | 	"os/signal" | ||||||
| 	"regexp" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"syscall" | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -61,12 +60,6 @@ var ( | ||||||
| 	ErrAccessDenied = errors.New("access denied") | 	ErrAccessDenied = errors.New("access denied") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // allowedRoute manages method + path based allowlists
 |  | ||||||
| type allowedRoute struct { |  | ||||||
| 	method    string |  | ||||||
| 	pathRegex *regexp.Regexp |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // OAuthProxy is the main authentication proxy
 | // OAuthProxy is the main authentication proxy
 | ||||||
| type OAuthProxy struct { | type OAuthProxy struct { | ||||||
| 	CookieOptions *options.Cookie | 	CookieOptions *options.Cookie | ||||||
|  | @ -74,7 +67,6 @@ type OAuthProxy struct { | ||||||
| 
 | 
 | ||||||
| 	SignInPath string | 	SignInPath string | ||||||
| 
 | 
 | ||||||
| 	allowedRoutes       []allowedRoute |  | ||||||
| 	redirectURL         *url.URL // the url to receive requests at
 | 	redirectURL         *url.URL // the url to receive requests at
 | ||||||
| 	whitelistDomains    []string | 	whitelistDomains    []string | ||||||
| 	provider            providers.Provider | 	provider            providers.Provider | ||||||
|  | @ -83,11 +75,9 @@ type OAuthProxy struct { | ||||||
| 	basicAuthValidator  basic.Validator | 	basicAuthValidator  basic.Validator | ||||||
| 	basicAuthGroups     []string | 	basicAuthGroups     []string | ||||||
| 	SkipProviderButton  bool | 	SkipProviderButton  bool | ||||||
| 	skipAuthPreflight   bool |  | ||||||
| 	skipJwtBearerTokens bool | 	skipJwtBearerTokens bool | ||||||
| 	forceJSONErrors     bool | 	forceJSONErrors     bool | ||||||
| 	realClientIPParser  ipapi.RealClientIPParser | 	realClientIPParser  ipapi.RealClientIPParser | ||||||
| 	trustedIPs          *ip.NetSet |  | ||||||
| 
 | 
 | ||||||
| 	sessionChain      alice.Chain | 	sessionChain      alice.Chain | ||||||
| 	headersChain      alice.Chain | 	headersChain      alice.Chain | ||||||
|  | @ -161,21 +151,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.Cookie.Name, opts.Cookie.Secure, opts.Cookie.HTTPOnly, opts.Cookie.Expire, strings.Join(opts.Cookie.Domains, ","), opts.Cookie.Path, opts.Cookie.SameSite, refresh) | 	logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.Cookie.Name, opts.Cookie.Secure, opts.Cookie.HTTPOnly, opts.Cookie.Expire, strings.Join(opts.Cookie.Domains, ","), opts.Cookie.Path, opts.Cookie.SameSite, refresh) | ||||||
| 
 | 
 | ||||||
| 	trustedIPs := ip.NewNetSet() | 	preAuthChain, err := buildPreAuthChain(opts, pageWriter) | ||||||
| 	for _, ipStr := range opts.TrustedIPs { |  | ||||||
| 		if ipNet := ip.ParseIPNet(ipStr); ipNet != nil { |  | ||||||
| 			trustedIPs.AddIPNet(*ipNet) |  | ||||||
| 		} else { |  | ||||||
| 			return nil, fmt.Errorf("could not parse IP network (%s)", ipStr) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	allowedRoutes, err := buildRoutesAllowlist(opts) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	preAuthChain, err := buildPreAuthChain(opts) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("could not build pre-auth chain: %v", err) | 		return nil, fmt.Errorf("could not build pre-auth chain: %v", err) | ||||||
| 	} | 	} | ||||||
|  | @ -201,14 +177,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		provider:            provider, | 		provider:            provider, | ||||||
| 		sessionStore:        sessionStore, | 		sessionStore:        sessionStore, | ||||||
| 		redirectURL:         redirectURL, | 		redirectURL:         redirectURL, | ||||||
| 		allowedRoutes:       allowedRoutes, |  | ||||||
| 		whitelistDomains:    opts.WhitelistDomains, | 		whitelistDomains:    opts.WhitelistDomains, | ||||||
| 		skipAuthPreflight:   opts.SkipAuthPreflight, |  | ||||||
| 		skipJwtBearerTokens: opts.SkipJwtBearerTokens, | 		skipJwtBearerTokens: opts.SkipJwtBearerTokens, | ||||||
| 		realClientIPParser:  opts.GetRealClientIPParser(), | 		realClientIPParser:  opts.GetRealClientIPParser(), | ||||||
| 		SkipProviderButton:  opts.SkipProviderButton, | 		SkipProviderButton:  opts.SkipProviderButton, | ||||||
| 		forceJSONErrors:     opts.ForceJSONErrors, | 		forceJSONErrors:     opts.ForceJSONErrors, | ||||||
| 		trustedIPs:          trustedIPs, |  | ||||||
| 
 | 
 | ||||||
| 		basicAuthValidator: basicAuthValidator, | 		basicAuthValidator: basicAuthValidator, | ||||||
| 		basicAuthGroups:    opts.HtpasswdUserGroups, | 		basicAuthGroups:    opts.HtpasswdUserGroups, | ||||||
|  | @ -316,7 +289,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { | ||||||
| // buildPreAuthChain constructs a chain that should process every request before
 | // buildPreAuthChain constructs a chain that should process every request before
 | ||||||
| // the OAuth2 Proxy authentication logic kicks in.
 | // the OAuth2 Proxy authentication logic kicks in.
 | ||||||
| // For example forcing HTTPS or health checks.
 | // For example forcing HTTPS or health checks.
 | ||||||
| func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | func buildPreAuthChain(opts *options.Options, pageWriter pagewriter.Writer) (alice.Chain, error) { | ||||||
| 	chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) | 	chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) | ||||||
| 
 | 
 | ||||||
| 	if opts.ForceHTTPS { | 	if opts.ForceHTTPS { | ||||||
|  | @ -351,6 +324,22 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||||
| 
 | 
 | ||||||
| 	chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) | 	chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) | ||||||
| 
 | 
 | ||||||
|  | 	requestAuthorization, err := middleware.NewRequestAuthorization(pageWriter, opts.Authorization.RequestRules, func(req *http.Request) net.IP { | ||||||
|  | 		if opts.GetRealClientIPParser() == nil { | ||||||
|  | 			host, _ := util.SplitHostPort(req.RemoteAddr) | ||||||
|  | 			return net.ParseIP(host) | ||||||
|  | 		} | ||||||
|  | 		ip, err := opts.GetRealClientIPParser().GetRealClientIP(req.Header) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 		return ip | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return alice.Chain{}, fmt.Errorf("error initialising request authorization middleware: %w", err) | ||||||
|  | 	} | ||||||
|  | 	chain = chain.Append(requestAuthorization) | ||||||
|  | 
 | ||||||
| 	return chain, nil | 	return chain, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -423,53 +412,6 @@ func buildProviderName(p providers.Provider, override string) string { | ||||||
| 	return p.Data().ProviderName | 	return p.Data().ProviderName | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // buildRoutesAllowlist builds an []allowedRoute  list from either the legacy
 |  | ||||||
| // SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option
 |  | ||||||
| // (method=path support)
 |  | ||||||
| func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { |  | ||||||
| 	routes := make([]allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes)) |  | ||||||
| 
 |  | ||||||
| 	for _, path := range opts.SkipAuthRegex { |  | ||||||
| 		compiledRegex, err := regexp.Compile(path) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		logger.Printf("Skipping auth - Method: ALL | Path: %s", path) |  | ||||||
| 		routes = append(routes, allowedRoute{ |  | ||||||
| 			method:    "", |  | ||||||
| 			pathRegex: compiledRegex, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, methodPath := range opts.SkipAuthRoutes { |  | ||||||
| 		var ( |  | ||||||
| 			method string |  | ||||||
| 			path   string |  | ||||||
| 		) |  | ||||||
| 
 |  | ||||||
| 		parts := strings.SplitN(methodPath, "=", 2) |  | ||||||
| 		if len(parts) == 1 { |  | ||||||
| 			method = "" |  | ||||||
| 			path = parts[0] |  | ||||||
| 		} else { |  | ||||||
| 			method = strings.ToUpper(parts[0]) |  | ||||||
| 			path = parts[1] |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		compiledRegex, err := regexp.Compile(path) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 		logger.Printf("Skipping auth - Method: %s | Path: %s", method, path) |  | ||||||
| 		routes = append(routes, allowedRoute{ |  | ||||||
| 			method:    method, |  | ||||||
| 			pathRegex: compiledRegex, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return routes, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | ||||||
| // stored in the user's session
 | // stored in the user's session
 | ||||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { | func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { | ||||||
|  | @ -512,38 +454,8 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i | ||||||
| 
 | 
 | ||||||
| // IsAllowedRequest is used to check if auth should be skipped for this request
 | // IsAllowedRequest is used to check if auth should be skipped for this request
 | ||||||
| func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { | func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { | ||||||
| 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" | 	scope := middlewareapi.GetRequestScope(req) | ||||||
| 	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) | 	return scope.Authorization.Policy == middlewareapi.AllowPolicy | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // IsAllowedRoute is used to check if the request method & path is allowed without auth
 |  | ||||||
| func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { |  | ||||||
| 	for _, route := range p.allowedRoutes { |  | ||||||
| 		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // isTrustedIP is used to check if a request comes from a trusted client IP address.
 |  | ||||||
| func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { |  | ||||||
| 	if p.trustedIPs == nil { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) |  | ||||||
| 		// Possibly spoofed X-Real-IP header
 |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if remoteAddr == nil { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return p.trustedIPs.Has(remoteAddr) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SignInPage writes the sign in template to the response
 | // SignInPage writes the sign in template to the response
 | ||||||
|  |  | ||||||
|  | @ -1358,7 +1358,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	opts.SkipAuthPreflight = true | 	opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||||
|  | 		{ | ||||||
|  | 			ID:      "skip-auth-preflight", | ||||||
|  | 			Methods: []string{http.MethodOptions}, | ||||||
|  | 			Policy:  options.AllowPolicy, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
| 	err := validation.Validate(opts) | 	err := validation.Validate(opts) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
|  | @ -1889,7 +1895,14 @@ func Test_noCacheHeaders(t *testing.T) { | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	opts.SkipAuthRegex = []string{".*"} | 	opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||||
|  | 		{ | ||||||
|  | 			ID:     "wildcard", | ||||||
|  | 			Path:   ".*", | ||||||
|  | 			Policy: options.AllowPolicy, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	err := validation.Validate(opts) | 	err := validation.Validate(opts) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) | 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) | ||||||
|  | @ -2161,7 +2174,16 @@ func TestTrustedIPs(t *testing.T) { | ||||||
| 					}, | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			} | 			} | ||||||
| 			opts.TrustedIPs = tt.trustedIPs | 			if len(tt.trustedIPs) > 0 { | ||||||
|  | 				opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||||
|  | 					{ | ||||||
|  | 						ID:     "trusted-ips", | ||||||
|  | 						IPs:    tt.trustedIPs, | ||||||
|  | 						Policy: options.AllowPolicy, | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
| 			opts.ReverseProxy = tt.reverseProxy | 			opts.ReverseProxy = tt.reverseProxy | ||||||
| 			opts.RealClientIPHeader = tt.realClientIPHeader | 			opts.RealClientIPHeader = tt.realClientIPHeader | ||||||
| 			err := validation.Validate(opts) | 			err := validation.Validate(opts) | ||||||
|  | @ -2181,160 +2203,160 @@ func TestTrustedIPs(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Test_buildRoutesAllowlist(t *testing.T) { | // func Test_buildRoutesAllowlist(t *testing.T) {
 | ||||||
| 	type expectedAllowedRoute struct { | // 	type expectedAllowedRoute struct {
 | ||||||
| 		method      string | // 		method      string
 | ||||||
| 		regexString string | // 		regexString string
 | ||||||
| 	} | // 	}
 | ||||||
| 
 | //
 | ||||||
| 	testCases := []struct { | // 	testCases := []struct {
 | ||||||
| 		name           string | // 		name           string
 | ||||||
| 		skipAuthRegex  []string | // 		skipAuthRegex  []string
 | ||||||
| 		skipAuthRoutes []string | // 		skipAuthRoutes []string
 | ||||||
| 		expectedRoutes []expectedAllowedRoute | // 		expectedRoutes []expectedAllowedRoute
 | ||||||
| 		shouldError    bool | // 		shouldError    bool
 | ||||||
| 	}{ | // 	}{
 | ||||||
| 		{ | // 		{
 | ||||||
| 			name:           "No skip auth configured", | // 			name:           "No skip auth configured",
 | ||||||
| 			skipAuthRegex:  []string{}, | // 			skipAuthRegex:  []string{},
 | ||||||
| 			skipAuthRoutes: []string{}, | // 			skipAuthRoutes: []string{},
 | ||||||
| 			expectedRoutes: []expectedAllowedRoute{}, | // 			expectedRoutes: []expectedAllowedRoute{},
 | ||||||
| 			shouldError:    false, | // 			shouldError:    false,
 | ||||||
| 		}, | // 		},
 | ||||||
| 		{ | // 		{
 | ||||||
| 			name: "Only skipAuthRegex configured", | // 			name: "Only skipAuthRegex configured",
 | ||||||
| 			skipAuthRegex: []string{ | // 			skipAuthRegex: []string{
 | ||||||
| 				"^/foo/bar", | // 				"^/foo/bar",
 | ||||||
| 				"^/baz/[0-9]+/thing", | // 				"^/baz/[0-9]+/thing",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			skipAuthRoutes: []string{}, | // 			skipAuthRoutes: []string{},
 | ||||||
| 			expectedRoutes: []expectedAllowedRoute{ | // 			expectedRoutes: []expectedAllowedRoute{
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "", | // 					method:      "",
 | ||||||
| 					regexString: "^/foo/bar", | // 					regexString: "^/foo/bar",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "", | // 					method:      "",
 | ||||||
| 					regexString: "^/baz/[0-9]+/thing", | // 					regexString: "^/baz/[0-9]+/thing",
 | ||||||
| 				}, | // 				},
 | ||||||
| 			}, | // 			},
 | ||||||
| 			shouldError: false, | // 			shouldError: false,
 | ||||||
| 		}, | // 		},
 | ||||||
| 		{ | // 		{
 | ||||||
| 			name:          "Only skipAuthRoutes configured", | // 			name:          "Only skipAuthRoutes configured",
 | ||||||
| 			skipAuthRegex: []string{}, | // 			skipAuthRegex: []string{},
 | ||||||
| 			skipAuthRoutes: []string{ | // 			skipAuthRoutes: []string{
 | ||||||
| 				"GET=^/foo/bar", | // 				"GET=^/foo/bar",
 | ||||||
| 				"POST=^/baz/[0-9]+/thing", | // 				"POST=^/baz/[0-9]+/thing",
 | ||||||
| 				"^/all/methods$", | // 				"^/all/methods$",
 | ||||||
| 				"WEIRD=^/methods/are/allowed", | // 				"WEIRD=^/methods/are/allowed",
 | ||||||
| 				"PATCH=/second/equals?are=handled&just=fine", | // 				"PATCH=/second/equals?are=handled&just=fine",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			expectedRoutes: []expectedAllowedRoute{ | // 			expectedRoutes: []expectedAllowedRoute{
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "GET", | // 					method:      "GET",
 | ||||||
| 					regexString: "^/foo/bar", | // 					regexString: "^/foo/bar",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "POST", | // 					method:      "POST",
 | ||||||
| 					regexString: "^/baz/[0-9]+/thing", | // 					regexString: "^/baz/[0-9]+/thing",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "", | // 					method:      "",
 | ||||||
| 					regexString: "^/all/methods$", | // 					regexString: "^/all/methods$",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "WEIRD", | // 					method:      "WEIRD",
 | ||||||
| 					regexString: "^/methods/are/allowed", | // 					regexString: "^/methods/are/allowed",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "PATCH", | // 					method:      "PATCH",
 | ||||||
| 					regexString: "/second/equals?are=handled&just=fine", | // 					regexString: "/second/equals?are=handled&just=fine",
 | ||||||
| 				}, | // 				},
 | ||||||
| 			}, | // 			},
 | ||||||
| 			shouldError: false, | // 			shouldError: false,
 | ||||||
| 		}, | // 		},
 | ||||||
| 		{ | // 		{
 | ||||||
| 			name: "Both skipAuthRegexes and skipAuthRoutes configured", | // 			name: "Both skipAuthRegexes and skipAuthRoutes configured",
 | ||||||
| 			skipAuthRegex: []string{ | // 			skipAuthRegex: []string{
 | ||||||
| 				"^/foo/bar/regex", | // 				"^/foo/bar/regex",
 | ||||||
| 				"^/baz/[0-9]+/thing/regex", | // 				"^/baz/[0-9]+/thing/regex",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			skipAuthRoutes: []string{ | // 			skipAuthRoutes: []string{
 | ||||||
| 				"GET=^/foo/bar", | // 				"GET=^/foo/bar",
 | ||||||
| 				"POST=^/baz/[0-9]+/thing", | // 				"POST=^/baz/[0-9]+/thing",
 | ||||||
| 				"^/all/methods$", | // 				"^/all/methods$",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			expectedRoutes: []expectedAllowedRoute{ | // 			expectedRoutes: []expectedAllowedRoute{
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "", | // 					method:      "",
 | ||||||
| 					regexString: "^/foo/bar/regex", | // 					regexString: "^/foo/bar/regex",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "", | // 					method:      "",
 | ||||||
| 					regexString: "^/baz/[0-9]+/thing/regex", | // 					regexString: "^/baz/[0-9]+/thing/regex",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "GET", | // 					method:      "GET",
 | ||||||
| 					regexString: "^/foo/bar", | // 					regexString: "^/foo/bar",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "POST", | // 					method:      "POST",
 | ||||||
| 					regexString: "^/baz/[0-9]+/thing", | // 					regexString: "^/baz/[0-9]+/thing",
 | ||||||
| 				}, | // 				},
 | ||||||
| 				{ | // 				{
 | ||||||
| 					method:      "", | // 					method:      "",
 | ||||||
| 					regexString: "^/all/methods$", | // 					regexString: "^/all/methods$",
 | ||||||
| 				}, | // 				},
 | ||||||
| 			}, | // 			},
 | ||||||
| 			shouldError: false, | // 			shouldError: false,
 | ||||||
| 		}, | // 		},
 | ||||||
| 		{ | // 		{
 | ||||||
| 			name: "Invalid skipAuthRegex entry", | // 			name: "Invalid skipAuthRegex entry",
 | ||||||
| 			skipAuthRegex: []string{ | // 			skipAuthRegex: []string{
 | ||||||
| 				"^/foo/bar", | // 				"^/foo/bar",
 | ||||||
| 				"^/baz/[0-9]+/thing", | // 				"^/baz/[0-9]+/thing",
 | ||||||
| 				"(bad[regex", | // 				"(bad[regex",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			skipAuthRoutes: []string{}, | // 			skipAuthRoutes: []string{},
 | ||||||
| 			expectedRoutes: []expectedAllowedRoute{}, | // 			expectedRoutes: []expectedAllowedRoute{},
 | ||||||
| 			shouldError:    true, | // 			shouldError:    true,
 | ||||||
| 		}, | // 		},
 | ||||||
| 		{ | // 		{
 | ||||||
| 			name:          "Invalid skipAuthRoutes entry", | // 			name:          "Invalid skipAuthRoutes entry",
 | ||||||
| 			skipAuthRegex: []string{}, | // 			skipAuthRegex: []string{},
 | ||||||
| 			skipAuthRoutes: []string{ | // 			skipAuthRoutes: []string{
 | ||||||
| 				"GET=^/foo/bar", | // 				"GET=^/foo/bar",
 | ||||||
| 				"POST=^/baz/[0-9]+/thing", | // 				"POST=^/baz/[0-9]+/thing",
 | ||||||
| 				"^/all/methods$", | // 				"^/all/methods$",
 | ||||||
| 				"PUT=(bad[regex", | // 				"PUT=(bad[regex",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			expectedRoutes: []expectedAllowedRoute{}, | // 			expectedRoutes: []expectedAllowedRoute{},
 | ||||||
| 			shouldError:    true, | // 			shouldError:    true,
 | ||||||
| 		}, | // 		},
 | ||||||
| 	} | // 	}
 | ||||||
| 
 | //
 | ||||||
| 	for _, tc := range testCases { | // 	for _, tc := range testCases {
 | ||||||
| 		t.Run(tc.name, func(t *testing.T) { | // 		t.Run(tc.name, func(t *testing.T) {
 | ||||||
| 			opts := &options.Options{ | // 			opts := &options.Options{
 | ||||||
| 				SkipAuthRegex:  tc.skipAuthRegex, | // 				SkipAuthRegex:  tc.skipAuthRegex,
 | ||||||
| 				SkipAuthRoutes: tc.skipAuthRoutes, | // 				SkipAuthRoutes: tc.skipAuthRoutes,
 | ||||||
| 			} | // 			}
 | ||||||
| 			routes, err := buildRoutesAllowlist(opts) | // 			routes, err := buildRoutesAllowlist(opts)
 | ||||||
| 			if tc.shouldError { | // 			if tc.shouldError {
 | ||||||
| 				assert.Error(t, err) | // 				assert.Error(t, err)
 | ||||||
| 				return | // 				return
 | ||||||
| 			} | // 			}
 | ||||||
| 			assert.NoError(t, err) | // 			assert.NoError(t, err)
 | ||||||
| 
 | //
 | ||||||
| 			for i, route := range routes { | // 			for i, route := range routes {
 | ||||||
| 				assert.Greater(t, len(tc.expectedRoutes), i) | // 				assert.Greater(t, len(tc.expectedRoutes), i)
 | ||||||
| 				assert.Equal(t, route.method, tc.expectedRoutes[i].method) | // 				assert.Equal(t, route.method, tc.expectedRoutes[i].method)
 | ||||||
| 				assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString) | // 				assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString)
 | ||||||
| 			} | // 			}
 | ||||||
| 		}) | // 		})
 | ||||||
| 	} | // 	}
 | ||||||
| } | // }
 | ||||||
| 
 | 
 | ||||||
| func TestAllowedRequest(t *testing.T) { | func TestAllowedRequest(t *testing.T) { | ||||||
| 	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | @ -2356,12 +2378,20 @@ func TestAllowedRequest(t *testing.T) { | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	opts.SkipAuthRegex = []string{ | 	opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||||
| 		"^/skip/auth/regex$", | 		{ | ||||||
| 	} | 			ID:     "regex", | ||||||
| 	opts.SkipAuthRoutes = []string{ | 			Path:   "^/skip/auth/regex$", | ||||||
| 		"GET=^/skip/auth/routes/get", | 			Policy: options.AllowPolicy, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ID:      "route", | ||||||
|  | 			Path:    "^/skip/auth/routes/get", | ||||||
|  | 			Methods: []string{http.MethodGet}, | ||||||
|  | 			Policy:  options.AllowPolicy, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	err := validation.Validate(opts) | 	err := validation.Validate(opts) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) | 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) | ||||||
|  | @ -2417,7 +2447,6 @@ func TestAllowedRequest(t *testing.T) { | ||||||
| 		t.Run(tc.name, func(t *testing.T) { | 		t.Run(tc.name, func(t *testing.T) { | ||||||
| 			req, err := http.NewRequest(tc.method, tc.url, nil) | 			req, err := http.NewRequest(tc.method, tc.url, nil) | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
| 			assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req)) |  | ||||||
| 
 | 
 | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 			proxy.ServeHTTP(rw, req) | 			proxy.ServeHTTP(rw, req) | ||||||
|  | @ -2670,8 +2699,18 @@ func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) { | ||||||
| 	for _, tc := range testCases { | 	for _, tc := range testCases { | ||||||
| 		t.Run(tc.name, func(t *testing.T) { | 		t.Run(tc.name, func(t *testing.T) { | ||||||
| 			test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) { | 			test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) { | ||||||
| 				opts.SkipAuthPreflight = true | 				opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||||
| 				opts.TrustedIPs = []string{"1.2.3.4"} | 					{ | ||||||
|  | 						ID:      "skip-auth-preflight", | ||||||
|  | 						Methods: []string{http.MethodOptions}, | ||||||
|  | 						Policy:  options.AllowPolicy, | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						ID:     "trusted-ips", | ||||||
|  | 						IPs:    []string{"1.2.3.4"}, | ||||||
|  | 						Policy: options.AllowPolicy, | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
| 			}) | 			}) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				t.Fatal(err) | 				t.Fatal(err) | ||||||
|  |  | ||||||
|  | @ -4,32 +4,24 @@ import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type AuthorizationPolicy int |  | ||||||
| 
 |  | ||||||
| const ( |  | ||||||
| 	NonePolicy AuthorizationPolicy = iota |  | ||||||
| 	AllowPolicy |  | ||||||
| 	DelegatePolicy |  | ||||||
| 	DenyPolicy |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type RuleSet interface { | type RuleSet interface { | ||||||
| 	MatchesRequest(req *http.Request) AuthorizationPolicy | 	MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type rule struct { | type rule struct { | ||||||
| 	conditions []condition | 	conditions []condition | ||||||
| 	policy     AuthorizationPolicy | 	policy     middlewareapi.AuthorizationPolicy | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r rule) matches(req *http.Request) AuthorizationPolicy { | func (r rule) matches(req *http.Request) middlewareapi.AuthorizationPolicy { | ||||||
| 	for _, condition := range r.conditions { | 	for _, condition := range r.conditions { | ||||||
| 		if !condition.matches(req) { | 		if !condition.matches(req) { | ||||||
| 			// One of the conditions didn't match so this rule does not apply
 | 			// One of the conditions didn't match so this rule does not apply
 | ||||||
| 			return NonePolicy | 			return middlewareapi.OmittedPolicy | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	// If all conditions match, return the configured rule policy
 | 	// If all conditions match, return the configured rule policy
 | ||||||
|  | @ -60,17 +52,17 @@ func newRule(authRule options.AuthorizationRule, getClientIPFunc func(*http.Requ | ||||||
| 		conditions = append(conditions, condition) | 		conditions = append(conditions, condition) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var policy AuthorizationPolicy | 	var policy middlewareapi.AuthorizationPolicy | ||||||
| 	switch authRule.Policy { | 	switch authRule.Policy { | ||||||
| 	case options.AllowPolicy: | 	case options.AllowPolicy: | ||||||
| 		policy = AllowPolicy | 		policy = middlewareapi.AllowPolicy | ||||||
| 	case options.DelegatePolicy: | 	case options.DelegatePolicy: | ||||||
| 		policy = DelegatePolicy | 		policy = middlewareapi.DelegatePolicy | ||||||
| 	case options.DenyPolicy: | 	case options.DenyPolicy: | ||||||
| 		policy = DenyPolicy | 		policy = middlewareapi.DenyPolicy | ||||||
| 	default: | 	default: | ||||||
| 		// This shouldn't be the case and should be prevented by validation
 | 		// This shouldn't be the case and should be prevented by validation
 | ||||||
| 		policy = NonePolicy | 		policy = middlewareapi.OmittedPolicy | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return rule{ | 	return rule{ | ||||||
|  | @ -83,15 +75,15 @@ type ruleSet struct { | ||||||
| 	rules []rule | 	rules []rule | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r ruleSet) MatchesRequest(req *http.Request) AuthorizationPolicy { | func (r ruleSet) MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy { | ||||||
| 	for _, rule := range r.rules { | 	for _, rule := range r.rules { | ||||||
| 		if policy := rule.matches(req); policy != NonePolicy { | 		if policy := rule.matches(req); policy != middlewareapi.OmittedPolicy { | ||||||
| 			// The rule applies to this request, return its policy
 | 			// The rule applies to this request, return its policy
 | ||||||
| 			return policy | 			return policy | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	// No rules matched
 | 	// No rules matched
 | ||||||
| 	return NonePolicy | 	return middlewareapi.OmittedPolicy | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewRuleSet(requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (RuleSet, error) { | func NewRuleSet(requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (RuleSet, error) { | ||||||
|  |  | ||||||
|  | @ -6,10 +6,11 @@ import ( | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var result AuthorizationPolicy | var result middlewareapi.AuthorizationPolicy | ||||||
| 
 | 
 | ||||||
| func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { | func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { | ||||||
| 	rule1 := options.AuthorizationRule{ | 	rule1 := options.AuthorizationRule{ | ||||||
|  | @ -53,10 +54,10 @@ func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { | ||||||
| 
 | 
 | ||||||
| 	req := httptest.NewRequest("GET", "/foo/bar/baz", nil) | 	req := httptest.NewRequest("GET", "/foo/bar/baz", nil) | ||||||
| 
 | 
 | ||||||
| 	var r AuthorizationPolicy | 	var r middlewareapi.AuthorizationPolicy | ||||||
| 	for n := 0; n < b.N; n++ { | 	for n := 0; n < b.N; n++ { | ||||||
| 		r = ruleSet.MatchesRequest(req) | 		r = ruleSet.MatchesRequest(req) | ||||||
| 		if r != NonePolicy { | 		if r != middlewareapi.OmittedPolicy { | ||||||
| 			b.Fatal("expected policy not to match") | 			b.Fatal("expected policy not to match") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -0,0 +1,61 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authorization" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func NewRequestAuthorization(writer pagewriter.Writer, requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (alice.Constructor, error) { | ||||||
|  | 	ruleset, err := authorization.NewRuleSet(requestRules, getClientIPFunc) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not initialise ruleset: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ra := &requestAuthorizer{ | ||||||
|  | 		ruleset: ruleset, | ||||||
|  | 		writer:  writer, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return ra.checkRequestAuthorization, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type requestAuthorizer struct { | ||||||
|  | 	ruleset authorization.RuleSet | ||||||
|  | 	writer  pagewriter.Writer | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (r *requestAuthorizer) checkRequestAuthorization(next http.Handler) http.Handler { | ||||||
|  | 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 		scope := middlewareapi.GetRequestScope(req) | ||||||
|  | 		// If scope is nil, this will panic.
 | ||||||
|  | 		// A scope should always be injected before this handler is called.
 | ||||||
|  | 		if scope.Authorization.Policy != middlewareapi.OmittedPolicy { | ||||||
|  | 			// The request was already authorized, pass to the next handler
 | ||||||
|  | 			next.ServeHTTP(rw, req) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		policy := r.ruleset.MatchesRequest(req) | ||||||
|  | 		switch policy { | ||||||
|  | 		case middlewareapi.AllowPolicy, middlewareapi.DelegatePolicy: | ||||||
|  | 			scope.Authorization.Type = middlewareapi.RequestAuthorization | ||||||
|  | 			scope.Authorization.Policy = policy | ||||||
|  | 		case middlewareapi.DenyPolicy: | ||||||
|  | 			r.writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{ | ||||||
|  | 				Status:    http.StatusForbidden, | ||||||
|  | 				RequestID: scope.RequestID, | ||||||
|  | 				AppError:  "Request denied by authorization policy", | ||||||
|  | 				Messages:  []interface{}{"Request denied by authorization policy"}, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		next.ServeHTTP(rw, req) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | @ -1,70 +0,0 @@ | ||||||
| package validation |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"os" |  | ||||||
| 	"regexp" |  | ||||||
| 	"strings" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func validateAllowlists(o *options.Options) []string { |  | ||||||
| 	msgs := []string{} |  | ||||||
| 
 |  | ||||||
| 	msgs = append(msgs, validateRoutes(o)...) |  | ||||||
| 	msgs = append(msgs, validateRegexes(o)...) |  | ||||||
| 	msgs = append(msgs, validateTrustedIPs(o)...) |  | ||||||
| 
 |  | ||||||
| 	if len(o.TrustedIPs) > 0 && o.ReverseProxy { |  | ||||||
| 		_, err := fmt.Fprintln(os.Stderr, "WARNING: mixing --trusted-ip with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy") |  | ||||||
| 		if err != nil { |  | ||||||
| 			panic(err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return msgs |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // validateRoutes validates method=path routes passed with options.SkipAuthRoutes
 |  | ||||||
| func validateRoutes(o *options.Options) []string { |  | ||||||
| 	msgs := []string{} |  | ||||||
| 	for _, route := range o.SkipAuthRoutes { |  | ||||||
| 		var regex string |  | ||||||
| 		parts := strings.SplitN(route, "=", 2) |  | ||||||
| 		if len(parts) == 1 { |  | ||||||
| 			regex = parts[0] |  | ||||||
| 		} else { |  | ||||||
| 			regex = parts[1] |  | ||||||
| 		} |  | ||||||
| 		_, err := regexp.Compile(regex) |  | ||||||
| 		if err != nil { |  | ||||||
| 			msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return msgs |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // validateRegex validates regex paths passed with options.SkipAuthRegex
 |  | ||||||
| func validateRegexes(o *options.Options) []string { |  | ||||||
| 	msgs := []string{} |  | ||||||
| 	for _, regex := range o.SkipAuthRegex { |  | ||||||
| 		_, err := regexp.Compile(regex) |  | ||||||
| 		if err != nil { |  | ||||||
| 			msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return msgs |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // validateTrustedIPs validates IP/CIDRs for IP based allowlists
 |  | ||||||
| func validateTrustedIPs(o *options.Options) []string { |  | ||||||
| 	msgs := []string{} |  | ||||||
| 	for i, ipStr := range o.TrustedIPs { |  | ||||||
| 		if nil == ip.ParseIPNet(ipStr) { |  | ||||||
| 			msgs = append(msgs, fmt.Sprintf("trusted_ips[%d] (%s) could not be recognized", i, ipStr)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return msgs |  | ||||||
| } |  | ||||||
|  | @ -1,125 +1,125 @@ | ||||||
| package validation | package validation | ||||||
| 
 | 
 | ||||||
| import ( | // import (
 | ||||||
| 	. "github.com/onsi/ginkgo" | // 	. "github.com/onsi/ginkgo"
 | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | // 	. "github.com/onsi/ginkgo/extensions/table"
 | ||||||
| 	. "github.com/onsi/gomega" | // 	. "github.com/onsi/gomega"
 | ||||||
| 
 | //
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | // 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | ||||||
| ) | // )
 | ||||||
| 
 | //
 | ||||||
| var _ = Describe("Allowlist", func() { | // var _ = Describe("Allowlist", func() {
 | ||||||
| 	type validateRoutesTableInput struct { | // 	type validateRoutesTableInput struct {
 | ||||||
| 		routes     []string | // 		routes     []string
 | ||||||
| 		errStrings []string | // 		errStrings []string
 | ||||||
| 	} | // 	}
 | ||||||
| 
 | //
 | ||||||
| 	type validateRegexesTableInput struct { | // 	type validateRegexesTableInput struct {
 | ||||||
| 		regexes    []string | // 		regexes    []string
 | ||||||
| 		errStrings []string | // 		errStrings []string
 | ||||||
| 	} | // 	}
 | ||||||
| 
 | //
 | ||||||
| 	type validateTrustedIPsTableInput struct { | // 	type validateTrustedIPsTableInput struct {
 | ||||||
| 		trustedIPs []string | // 		trustedIPs []string
 | ||||||
| 		errStrings []string | // 		errStrings []string
 | ||||||
| 	} | // 	}
 | ||||||
| 
 | //
 | ||||||
| 	DescribeTable("validateRoutes", | // 	DescribeTable("validateRoutes",
 | ||||||
| 		func(r *validateRoutesTableInput) { | // 		func(r *validateRoutesTableInput) {
 | ||||||
| 			opts := &options.Options{ | // 			opts := &options.Options{
 | ||||||
| 				SkipAuthRoutes: r.routes, | // 				SkipAuthRoutes: r.routes,
 | ||||||
| 			} | // 			}
 | ||||||
| 			Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings)) | // 			Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings))
 | ||||||
| 		}, | // 		},
 | ||||||
| 		Entry("Valid regex routes", &validateRoutesTableInput{ | // 		Entry("Valid regex routes", &validateRoutesTableInput{
 | ||||||
| 			routes: []string{ | // 			routes: []string{
 | ||||||
| 				"/foo", | // 				"/foo",
 | ||||||
| 				"POST=/foo/bar", | // 				"POST=/foo/bar",
 | ||||||
| 				"PUT=^/foo/bar$", | // 				"PUT=^/foo/bar$",
 | ||||||
| 				"DELETE=/crazy/(?:regex)?/[^/]+/stuff$", | // 				"DELETE=/crazy/(?:regex)?/[^/]+/stuff$",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			errStrings: []string{}, | // 			errStrings: []string{},
 | ||||||
| 		}), | // 		}),
 | ||||||
| 		Entry("Bad regexes do not compile", &validateRoutesTableInput{ | // 		Entry("Bad regexes do not compile", &validateRoutesTableInput{
 | ||||||
| 			routes: []string{ | // 			routes: []string{
 | ||||||
| 				"POST=/(foo", | // 				"POST=/(foo",
 | ||||||
| 				"OPTIONS=/foo/bar)", | // 				"OPTIONS=/foo/bar)",
 | ||||||
| 				"GET=^]/foo/bar[$", | // 				"GET=^]/foo/bar[$",
 | ||||||
| 				"GET=^]/foo/bar[$", | // 				"GET=^]/foo/bar[$",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			errStrings: []string{ | // 			errStrings: []string{
 | ||||||
| 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", | // 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`",
 | ||||||
| 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", | // 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`",
 | ||||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||||
| 			}, | // 			},
 | ||||||
| 		}), | // 		}),
 | ||||||
| 	) | // 	)
 | ||||||
| 
 | //
 | ||||||
| 	DescribeTable("validateRegexes", | // 	DescribeTable("validateRegexes",
 | ||||||
| 		func(r *validateRegexesTableInput) { | // 		func(r *validateRegexesTableInput) {
 | ||||||
| 			opts := &options.Options{ | // 			opts := &options.Options{
 | ||||||
| 				SkipAuthRegex: r.regexes, | // 				SkipAuthRegex: r.regexes,
 | ||||||
| 			} | // 			}
 | ||||||
| 			Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings)) | // 			Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings))
 | ||||||
| 		}, | // 		},
 | ||||||
| 		Entry("Valid regex routes", &validateRegexesTableInput{ | // 		Entry("Valid regex routes", &validateRegexesTableInput{
 | ||||||
| 			regexes: []string{ | // 			regexes: []string{
 | ||||||
| 				"/foo", | // 				"/foo",
 | ||||||
| 				"/foo/bar", | // 				"/foo/bar",
 | ||||||
| 				"^/foo/bar$", | // 				"^/foo/bar$",
 | ||||||
| 				"/crazy/(?:regex)?/[^/]+/stuff$", | // 				"/crazy/(?:regex)?/[^/]+/stuff$",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			errStrings: []string{}, | // 			errStrings: []string{},
 | ||||||
| 		}), | // 		}),
 | ||||||
| 		Entry("Bad regexes do not compile", &validateRegexesTableInput{ | // 		Entry("Bad regexes do not compile", &validateRegexesTableInput{
 | ||||||
| 			regexes: []string{ | // 			regexes: []string{
 | ||||||
| 				"/(foo", | // 				"/(foo",
 | ||||||
| 				"/foo/bar)", | // 				"/foo/bar)",
 | ||||||
| 				"^]/foo/bar[$", | // 				"^]/foo/bar[$",
 | ||||||
| 				"^]/foo/bar[$", | // 				"^]/foo/bar[$",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			errStrings: []string{ | // 			errStrings: []string{
 | ||||||
| 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", | // 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`",
 | ||||||
| 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", | // 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`",
 | ||||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||||
| 			}, | // 			},
 | ||||||
| 		}), | // 		}),
 | ||||||
| 	) | // 	)
 | ||||||
| 
 | //
 | ||||||
| 	DescribeTable("validateTrustedIPs", | // 	DescribeTable("validateTrustedIPs",
 | ||||||
| 		func(t *validateTrustedIPsTableInput) { | // 		func(t *validateTrustedIPsTableInput) {
 | ||||||
| 			opts := &options.Options{ | // 			opts := &options.Options{
 | ||||||
| 				TrustedIPs: t.trustedIPs, | // 				TrustedIPs: t.trustedIPs,
 | ||||||
| 			} | // 			}
 | ||||||
| 			Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings)) | // 			Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings))
 | ||||||
| 		}, | // 		},
 | ||||||
| 		Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{ | // 		Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{
 | ||||||
| 			trustedIPs: []string{ | // 			trustedIPs: []string{
 | ||||||
| 				"127.0.0.1", | // 				"127.0.0.1",
 | ||||||
| 				"10.32.0.1/32", | // 				"10.32.0.1/32",
 | ||||||
| 				"43.36.201.0/24", | // 				"43.36.201.0/24",
 | ||||||
| 				"::1", | // 				"::1",
 | ||||||
| 				"2a12:105:ee7:9234:0:0:0:0/64", | // 				"2a12:105:ee7:9234:0:0:0:0/64",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			errStrings: []string{}, | // 			errStrings: []string{},
 | ||||||
| 		}), | // 		}),
 | ||||||
| 		Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{ | // 		Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{
 | ||||||
| 			trustedIPs: []string{ | // 			trustedIPs: []string{
 | ||||||
| 				"135.180.78.199", | // 				"135.180.78.199",
 | ||||||
| 				"135.180.78.199/32", | // 				"135.180.78.199/32",
 | ||||||
| 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4", | // 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4",
 | ||||||
| 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128", | // 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128",
 | ||||||
| 			}, | // 			},
 | ||||||
| 			errStrings: []string{}, | // 			errStrings: []string{},
 | ||||||
| 		}), | // 		}),
 | ||||||
| 		Entry("Invalid IPs", &validateTrustedIPsTableInput{ | // 		Entry("Invalid IPs", &validateTrustedIPsTableInput{
 | ||||||
| 			trustedIPs: []string{"[::1]", "alkwlkbn/32"}, | // 			trustedIPs: []string{"[::1]", "alkwlkbn/32"},
 | ||||||
| 			errStrings: []string{ | // 			errStrings: []string{
 | ||||||
| 				"trusted_ips[0] ([::1]) could not be recognized", | // 				"trusted_ips[0] ([::1]) could not be recognized",
 | ||||||
| 				"trusted_ips[1] (alkwlkbn/32) could not be recognized", | // 				"trusted_ips[1] (alkwlkbn/32) could not be recognized",
 | ||||||
| 			}, | // 			},
 | ||||||
| 		}), | // 		}),
 | ||||||
| 	) | // 	)
 | ||||||
| }) | // })
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,94 @@ | ||||||
|  | package validation | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"regexp" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func validateAuthorization(authorization options.Authorization, reverseProxy bool) []string { | ||||||
|  | 	msgs := []string{} | ||||||
|  | 
 | ||||||
|  | 	msgs = append(msgs, validateRequestRules(authorization.RequestRules, reverseProxy)...) | ||||||
|  | 
 | ||||||
|  | 	return msgs | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func validateRequestRules(rules []options.AuthorizationRule, reverseProxy bool) []string { | ||||||
|  | 	msgs := []string{} | ||||||
|  | 
 | ||||||
|  | 	ids := make(map[string]struct{}) | ||||||
|  | 
 | ||||||
|  | 	for _, rule := range rules { | ||||||
|  | 		msgs = append(msgs, validateRequestRule(ids, rule, reverseProxy)...) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return msgs | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func validateRequestRule(ids map[string]struct{}, rule options.AuthorizationRule, reverseProxy bool) []string { | ||||||
|  | 	msgs := []string{} | ||||||
|  | 
 | ||||||
|  | 	if rule.ID == "" { | ||||||
|  | 		msgs = append(msgs, "request rule has empty ID: IDs are required for all request rules") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, ok := ids[rule.ID]; ok { | ||||||
|  | 		msgs = append(msgs, fmt.Sprintf("multiple request rules found with ID %q: request rule IDs must be unique", rule.ID)) | ||||||
|  | 	} | ||||||
|  | 	ids[rule.ID] = struct{}{} | ||||||
|  | 
 | ||||||
|  | 	msgs = append(msgs, validateRequestRulePolicy(rule.ID, rule.Policy)...) | ||||||
|  | 	msgs = append(msgs, validateRequestRulePath(rule.ID, rule.Path)...) | ||||||
|  | 	msgs = append(msgs, validateRequestRuleIPs(rule.ID, rule.IPs, reverseProxy)...) | ||||||
|  | 
 | ||||||
|  | 	return msgs | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func validateRequestRulePolicy(ruleID string, policy options.AuthorizationPolicy) []string { | ||||||
|  | 	msgs := []string{} | ||||||
|  | 
 | ||||||
|  | 	switch policy { | ||||||
|  | 	case options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy: | ||||||
|  | 		// Do nothing for valid options
 | ||||||
|  | 	default: | ||||||
|  | 		msgs = append(msgs, fmt.Sprintf("request rule %q has invalid policy (%s): policy must be one of %s, %s or %s", ruleID, policy, options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return msgs | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // validateRequestRulePath validates paths for path/regex based conditions
 | ||||||
|  | func validateRequestRulePath(ruleID string, path string) []string { | ||||||
|  | 	msgs := []string{} | ||||||
|  | 
 | ||||||
|  | 	_, err := regexp.Compile(path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		msgs = append(msgs, fmt.Sprintf("error compiling path regex (%s) for rule %q: %v", path, ruleID, err)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return msgs | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // validateRequestRuleIPs validates IP/CIDRs for IP based conditions.
 | ||||||
|  | func validateRequestRuleIPs(ruleID string, ips []string, reverseProxy bool) []string { | ||||||
|  | 	msgs := []string{} | ||||||
|  | 
 | ||||||
|  | 	if len(ips) > 0 && reverseProxy { | ||||||
|  | 		_, err := fmt.Fprintln(os.Stderr, "WARNING: mixing IP authorization with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy") | ||||||
|  | 		if err != nil { | ||||||
|  | 			panic(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for i, ipStr := range ips { | ||||||
|  | 		if nil == ip.ParseIPNet(ipStr) { | ||||||
|  | 			msgs = append(msgs, fmt.Sprintf("rule %q IP [%d] (%s) could not be recognized", ruleID, i, ipStr)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return msgs | ||||||
|  | } | ||||||
|  | @ -20,6 +20,7 @@ import ( | ||||||
| // are of the correct format
 | // are of the correct format
 | ||||||
| func Validate(o *options.Options) error { | func Validate(o *options.Options) error { | ||||||
| 	msgs := validateCookie(o.Cookie) | 	msgs := validateCookie(o.Cookie) | ||||||
|  | 	msgs = append(msgs, validateAuthorization(o.Authorization, o.ReverseProxy)...) | ||||||
| 	msgs = append(msgs, validateSessionCookieMinimal(o)...) | 	msgs = append(msgs, validateSessionCookieMinimal(o)...) | ||||||
| 	msgs = append(msgs, validateRedisSessionStore(o)...) | 	msgs = append(msgs, validateRedisSessionStore(o)...) | ||||||
| 	msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...) | 	msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...) | ||||||
|  | @ -96,9 +97,6 @@ func Validate(o *options.Options) error { | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Do this after ReverseProxy validation for TrustedIP coordinated checks
 |  | ||||||
| 	msgs = append(msgs, validateAllowlists(o)...) |  | ||||||
| 
 |  | ||||||
| 	if len(msgs) != 0 { | 	if len(msgs) != 0 { | ||||||
| 		return fmt.Errorf("invalid configuration:\n  %s", | 		return fmt.Errorf("invalid configuration:\n  %s", | ||||||
| 			strings.Join(msgs, "\n  ")) | 			strings.Join(msgs, "\n  ")) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue