WIP
This commit is contained in:
		
							parent
							
								
									374a676c9d
								
							
						
					
					
						commit
						0dbda5dfac
					
				
							
								
								
									
										128
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										128
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -11,7 +11,6 @@ import ( | |||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | @ -61,12 +60,6 @@ var ( | |||
| 	ErrAccessDenied = errors.New("access denied") | ||||
| ) | ||||
| 
 | ||||
| // allowedRoute manages method + path based allowlists
 | ||||
| type allowedRoute struct { | ||||
| 	method    string | ||||
| 	pathRegex *regexp.Regexp | ||||
| } | ||||
| 
 | ||||
| // OAuthProxy is the main authentication proxy
 | ||||
| type OAuthProxy struct { | ||||
| 	CookieOptions *options.Cookie | ||||
|  | @ -74,7 +67,6 @@ type OAuthProxy struct { | |||
| 
 | ||||
| 	SignInPath string | ||||
| 
 | ||||
| 	allowedRoutes       []allowedRoute | ||||
| 	redirectURL         *url.URL // the url to receive requests at
 | ||||
| 	whitelistDomains    []string | ||||
| 	provider            providers.Provider | ||||
|  | @ -83,11 +75,9 @@ type OAuthProxy struct { | |||
| 	basicAuthValidator  basic.Validator | ||||
| 	basicAuthGroups     []string | ||||
| 	SkipProviderButton  bool | ||||
| 	skipAuthPreflight   bool | ||||
| 	skipJwtBearerTokens bool | ||||
| 	forceJSONErrors     bool | ||||
| 	realClientIPParser  ipapi.RealClientIPParser | ||||
| 	trustedIPs          *ip.NetSet | ||||
| 
 | ||||
| 	sessionChain      alice.Chain | ||||
| 	headersChain      alice.Chain | ||||
|  | @ -161,21 +151,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 
 | ||||
| 	logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.Cookie.Name, opts.Cookie.Secure, opts.Cookie.HTTPOnly, opts.Cookie.Expire, strings.Join(opts.Cookie.Domains, ","), opts.Cookie.Path, opts.Cookie.SameSite, refresh) | ||||
| 
 | ||||
| 	trustedIPs := ip.NewNetSet() | ||||
| 	for _, ipStr := range opts.TrustedIPs { | ||||
| 		if ipNet := ip.ParseIPNet(ipStr); ipNet != nil { | ||||
| 			trustedIPs.AddIPNet(*ipNet) | ||||
| 		} else { | ||||
| 			return nil, fmt.Errorf("could not parse IP network (%s)", ipStr) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	allowedRoutes, err := buildRoutesAllowlist(opts) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	preAuthChain, err := buildPreAuthChain(opts) | ||||
| 	preAuthChain, err := buildPreAuthChain(opts, pageWriter) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not build pre-auth chain: %v", err) | ||||
| 	} | ||||
|  | @ -201,14 +177,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		provider:            provider, | ||||
| 		sessionStore:        sessionStore, | ||||
| 		redirectURL:         redirectURL, | ||||
| 		allowedRoutes:       allowedRoutes, | ||||
| 		whitelistDomains:    opts.WhitelistDomains, | ||||
| 		skipAuthPreflight:   opts.SkipAuthPreflight, | ||||
| 		skipJwtBearerTokens: opts.SkipJwtBearerTokens, | ||||
| 		realClientIPParser:  opts.GetRealClientIPParser(), | ||||
| 		SkipProviderButton:  opts.SkipProviderButton, | ||||
| 		forceJSONErrors:     opts.ForceJSONErrors, | ||||
| 		trustedIPs:          trustedIPs, | ||||
| 
 | ||||
| 		basicAuthValidator: basicAuthValidator, | ||||
| 		basicAuthGroups:    opts.HtpasswdUserGroups, | ||||
|  | @ -316,7 +289,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { | |||
| // buildPreAuthChain constructs a chain that should process every request before
 | ||||
| // the OAuth2 Proxy authentication logic kicks in.
 | ||||
| // For example forcing HTTPS or health checks.
 | ||||
| func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||
| func buildPreAuthChain(opts *options.Options, pageWriter pagewriter.Writer) (alice.Chain, error) { | ||||
| 	chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) | ||||
| 
 | ||||
| 	if opts.ForceHTTPS { | ||||
|  | @ -351,6 +324,22 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | |||
| 
 | ||||
| 	chain = chain.Append(middleware.NewRequestMetricsWithDefaultRegistry()) | ||||
| 
 | ||||
| 	requestAuthorization, err := middleware.NewRequestAuthorization(pageWriter, opts.Authorization.RequestRules, func(req *http.Request) net.IP { | ||||
| 		if opts.GetRealClientIPParser() == nil { | ||||
| 			host, _ := util.SplitHostPort(req.RemoteAddr) | ||||
| 			return net.ParseIP(host) | ||||
| 		} | ||||
| 		ip, err := opts.GetRealClientIPParser().GetRealClientIP(req.Header) | ||||
| 		if err != nil { | ||||
| 			return nil | ||||
| 		} | ||||
| 		return ip | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return alice.Chain{}, fmt.Errorf("error initialising request authorization middleware: %w", err) | ||||
| 	} | ||||
| 	chain = chain.Append(requestAuthorization) | ||||
| 
 | ||||
| 	return chain, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -423,53 +412,6 @@ func buildProviderName(p providers.Provider, override string) string { | |||
| 	return p.Data().ProviderName | ||||
| } | ||||
| 
 | ||||
| // buildRoutesAllowlist builds an []allowedRoute  list from either the legacy
 | ||||
| // SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option
 | ||||
| // (method=path support)
 | ||||
| func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { | ||||
| 	routes := make([]allowedRoute, 0, len(opts.SkipAuthRegex)+len(opts.SkipAuthRoutes)) | ||||
| 
 | ||||
| 	for _, path := range opts.SkipAuthRegex { | ||||
| 		compiledRegex, err := regexp.Compile(path) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		logger.Printf("Skipping auth - Method: ALL | Path: %s", path) | ||||
| 		routes = append(routes, allowedRoute{ | ||||
| 			method:    "", | ||||
| 			pathRegex: compiledRegex, | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, methodPath := range opts.SkipAuthRoutes { | ||||
| 		var ( | ||||
| 			method string | ||||
| 			path   string | ||||
| 		) | ||||
| 
 | ||||
| 		parts := strings.SplitN(methodPath, "=", 2) | ||||
| 		if len(parts) == 1 { | ||||
| 			method = "" | ||||
| 			path = parts[0] | ||||
| 		} else { | ||||
| 			method = strings.ToUpper(parts[0]) | ||||
| 			path = parts[1] | ||||
| 		} | ||||
| 
 | ||||
| 		compiledRegex, err := regexp.Compile(path) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		logger.Printf("Skipping auth - Method: %s | Path: %s", method, path) | ||||
| 		routes = append(routes, allowedRoute{ | ||||
| 			method:    method, | ||||
| 			pathRegex: compiledRegex, | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	return routes, nil | ||||
| } | ||||
| 
 | ||||
| // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | ||||
| // stored in the user's session
 | ||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { | ||||
|  | @ -512,38 +454,8 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i | |||
| 
 | ||||
| // IsAllowedRequest is used to check if auth should be skipped for this request
 | ||||
| func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { | ||||
| 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" | ||||
| 	return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) | ||||
| } | ||||
| 
 | ||||
| // IsAllowedRoute is used to check if the request method & path is allowed without auth
 | ||||
| func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { | ||||
| 	for _, route := range p.allowedRoutes { | ||||
| 		if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // isTrustedIP is used to check if a request comes from a trusted client IP address.
 | ||||
| func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { | ||||
| 	if p.trustedIPs == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) | ||||
| 		// Possibly spoofed X-Real-IP header
 | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if remoteAddr == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	return p.trustedIPs.Has(remoteAddr) | ||||
| 	scope := middlewareapi.GetRequestScope(req) | ||||
| 	return scope.Authorization.Policy == middlewareapi.AllowPolicy | ||||
| } | ||||
| 
 | ||||
| // SignInPage writes the sign in template to the response
 | ||||
|  |  | |||
|  | @ -1358,7 +1358,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { | |||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	opts.SkipAuthPreflight = true | ||||
| 	opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||
| 		{ | ||||
| 			ID:      "skip-auth-preflight", | ||||
| 			Methods: []string{http.MethodOptions}, | ||||
| 			Policy:  options.AllowPolicy, | ||||
| 		}, | ||||
| 	} | ||||
| 	err := validation.Validate(opts) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
|  | @ -1889,7 +1895,14 @@ func Test_noCacheHeaders(t *testing.T) { | |||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	opts.SkipAuthRegex = []string{".*"} | ||||
| 	opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||
| 		{ | ||||
| 			ID:     "wildcard", | ||||
| 			Path:   ".*", | ||||
| 			Policy: options.AllowPolicy, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	err := validation.Validate(opts) | ||||
| 	assert.NoError(t, err) | ||||
| 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) | ||||
|  | @ -2161,7 +2174,16 @@ func TestTrustedIPs(t *testing.T) { | |||
| 					}, | ||||
| 				}, | ||||
| 			} | ||||
| 			opts.TrustedIPs = tt.trustedIPs | ||||
| 			if len(tt.trustedIPs) > 0 { | ||||
| 				opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||
| 					{ | ||||
| 						ID:     "trusted-ips", | ||||
| 						IPs:    tt.trustedIPs, | ||||
| 						Policy: options.AllowPolicy, | ||||
| 					}, | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			opts.ReverseProxy = tt.reverseProxy | ||||
| 			opts.RealClientIPHeader = tt.realClientIPHeader | ||||
| 			err := validation.Validate(opts) | ||||
|  | @ -2181,160 +2203,160 @@ func TestTrustedIPs(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Test_buildRoutesAllowlist(t *testing.T) { | ||||
| 	type expectedAllowedRoute struct { | ||||
| 		method      string | ||||
| 		regexString string | ||||
| 	} | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		name           string | ||||
| 		skipAuthRegex  []string | ||||
| 		skipAuthRoutes []string | ||||
| 		expectedRoutes []expectedAllowedRoute | ||||
| 		shouldError    bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:           "No skip auth configured", | ||||
| 			skipAuthRegex:  []string{}, | ||||
| 			skipAuthRoutes: []string{}, | ||||
| 			expectedRoutes: []expectedAllowedRoute{}, | ||||
| 			shouldError:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Only skipAuthRegex configured", | ||||
| 			skipAuthRegex: []string{ | ||||
| 				"^/foo/bar", | ||||
| 				"^/baz/[0-9]+/thing", | ||||
| 			}, | ||||
| 			skipAuthRoutes: []string{}, | ||||
| 			expectedRoutes: []expectedAllowedRoute{ | ||||
| 				{ | ||||
| 					method:      "", | ||||
| 					regexString: "^/foo/bar", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "", | ||||
| 					regexString: "^/baz/[0-9]+/thing", | ||||
| 				}, | ||||
| 			}, | ||||
| 			shouldError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "Only skipAuthRoutes configured", | ||||
| 			skipAuthRegex: []string{}, | ||||
| 			skipAuthRoutes: []string{ | ||||
| 				"GET=^/foo/bar", | ||||
| 				"POST=^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 				"WEIRD=^/methods/are/allowed", | ||||
| 				"PATCH=/second/equals?are=handled&just=fine", | ||||
| 			}, | ||||
| 			expectedRoutes: []expectedAllowedRoute{ | ||||
| 				{ | ||||
| 					method:      "GET", | ||||
| 					regexString: "^/foo/bar", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "POST", | ||||
| 					regexString: "^/baz/[0-9]+/thing", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "", | ||||
| 					regexString: "^/all/methods$", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "WEIRD", | ||||
| 					regexString: "^/methods/are/allowed", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "PATCH", | ||||
| 					regexString: "/second/equals?are=handled&just=fine", | ||||
| 				}, | ||||
| 			}, | ||||
| 			shouldError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Both skipAuthRegexes and skipAuthRoutes configured", | ||||
| 			skipAuthRegex: []string{ | ||||
| 				"^/foo/bar/regex", | ||||
| 				"^/baz/[0-9]+/thing/regex", | ||||
| 			}, | ||||
| 			skipAuthRoutes: []string{ | ||||
| 				"GET=^/foo/bar", | ||||
| 				"POST=^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 			}, | ||||
| 			expectedRoutes: []expectedAllowedRoute{ | ||||
| 				{ | ||||
| 					method:      "", | ||||
| 					regexString: "^/foo/bar/regex", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "", | ||||
| 					regexString: "^/baz/[0-9]+/thing/regex", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "GET", | ||||
| 					regexString: "^/foo/bar", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "POST", | ||||
| 					regexString: "^/baz/[0-9]+/thing", | ||||
| 				}, | ||||
| 				{ | ||||
| 					method:      "", | ||||
| 					regexString: "^/all/methods$", | ||||
| 				}, | ||||
| 			}, | ||||
| 			shouldError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Invalid skipAuthRegex entry", | ||||
| 			skipAuthRegex: []string{ | ||||
| 				"^/foo/bar", | ||||
| 				"^/baz/[0-9]+/thing", | ||||
| 				"(bad[regex", | ||||
| 			}, | ||||
| 			skipAuthRoutes: []string{}, | ||||
| 			expectedRoutes: []expectedAllowedRoute{}, | ||||
| 			shouldError:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:          "Invalid skipAuthRoutes entry", | ||||
| 			skipAuthRegex: []string{}, | ||||
| 			skipAuthRoutes: []string{ | ||||
| 				"GET=^/foo/bar", | ||||
| 				"POST=^/baz/[0-9]+/thing", | ||||
| 				"^/all/methods$", | ||||
| 				"PUT=(bad[regex", | ||||
| 			}, | ||||
| 			expectedRoutes: []expectedAllowedRoute{}, | ||||
| 			shouldError:    true, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			opts := &options.Options{ | ||||
| 				SkipAuthRegex:  tc.skipAuthRegex, | ||||
| 				SkipAuthRoutes: tc.skipAuthRoutes, | ||||
| 			} | ||||
| 			routes, err := buildRoutesAllowlist(opts) | ||||
| 			if tc.shouldError { | ||||
| 				assert.Error(t, err) | ||||
| 				return | ||||
| 			} | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			for i, route := range routes { | ||||
| 				assert.Greater(t, len(tc.expectedRoutes), i) | ||||
| 				assert.Equal(t, route.method, tc.expectedRoutes[i].method) | ||||
| 				assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| // func Test_buildRoutesAllowlist(t *testing.T) {
 | ||||
| // 	type expectedAllowedRoute struct {
 | ||||
| // 		method      string
 | ||||
| // 		regexString string
 | ||||
| // 	}
 | ||||
| //
 | ||||
| // 	testCases := []struct {
 | ||||
| // 		name           string
 | ||||
| // 		skipAuthRegex  []string
 | ||||
| // 		skipAuthRoutes []string
 | ||||
| // 		expectedRoutes []expectedAllowedRoute
 | ||||
| // 		shouldError    bool
 | ||||
| // 	}{
 | ||||
| // 		{
 | ||||
| // 			name:           "No skip auth configured",
 | ||||
| // 			skipAuthRegex:  []string{},
 | ||||
| // 			skipAuthRoutes: []string{},
 | ||||
| // 			expectedRoutes: []expectedAllowedRoute{},
 | ||||
| // 			shouldError:    false,
 | ||||
| // 		},
 | ||||
| // 		{
 | ||||
| // 			name: "Only skipAuthRegex configured",
 | ||||
| // 			skipAuthRegex: []string{
 | ||||
| // 				"^/foo/bar",
 | ||||
| // 				"^/baz/[0-9]+/thing",
 | ||||
| // 			},
 | ||||
| // 			skipAuthRoutes: []string{},
 | ||||
| // 			expectedRoutes: []expectedAllowedRoute{
 | ||||
| // 				{
 | ||||
| // 					method:      "",
 | ||||
| // 					regexString: "^/foo/bar",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "",
 | ||||
| // 					regexString: "^/baz/[0-9]+/thing",
 | ||||
| // 				},
 | ||||
| // 			},
 | ||||
| // 			shouldError: false,
 | ||||
| // 		},
 | ||||
| // 		{
 | ||||
| // 			name:          "Only skipAuthRoutes configured",
 | ||||
| // 			skipAuthRegex: []string{},
 | ||||
| // 			skipAuthRoutes: []string{
 | ||||
| // 				"GET=^/foo/bar",
 | ||||
| // 				"POST=^/baz/[0-9]+/thing",
 | ||||
| // 				"^/all/methods$",
 | ||||
| // 				"WEIRD=^/methods/are/allowed",
 | ||||
| // 				"PATCH=/second/equals?are=handled&just=fine",
 | ||||
| // 			},
 | ||||
| // 			expectedRoutes: []expectedAllowedRoute{
 | ||||
| // 				{
 | ||||
| // 					method:      "GET",
 | ||||
| // 					regexString: "^/foo/bar",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "POST",
 | ||||
| // 					regexString: "^/baz/[0-9]+/thing",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "",
 | ||||
| // 					regexString: "^/all/methods$",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "WEIRD",
 | ||||
| // 					regexString: "^/methods/are/allowed",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "PATCH",
 | ||||
| // 					regexString: "/second/equals?are=handled&just=fine",
 | ||||
| // 				},
 | ||||
| // 			},
 | ||||
| // 			shouldError: false,
 | ||||
| // 		},
 | ||||
| // 		{
 | ||||
| // 			name: "Both skipAuthRegexes and skipAuthRoutes configured",
 | ||||
| // 			skipAuthRegex: []string{
 | ||||
| // 				"^/foo/bar/regex",
 | ||||
| // 				"^/baz/[0-9]+/thing/regex",
 | ||||
| // 			},
 | ||||
| // 			skipAuthRoutes: []string{
 | ||||
| // 				"GET=^/foo/bar",
 | ||||
| // 				"POST=^/baz/[0-9]+/thing",
 | ||||
| // 				"^/all/methods$",
 | ||||
| // 			},
 | ||||
| // 			expectedRoutes: []expectedAllowedRoute{
 | ||||
| // 				{
 | ||||
| // 					method:      "",
 | ||||
| // 					regexString: "^/foo/bar/regex",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "",
 | ||||
| // 					regexString: "^/baz/[0-9]+/thing/regex",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "GET",
 | ||||
| // 					regexString: "^/foo/bar",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "POST",
 | ||||
| // 					regexString: "^/baz/[0-9]+/thing",
 | ||||
| // 				},
 | ||||
| // 				{
 | ||||
| // 					method:      "",
 | ||||
| // 					regexString: "^/all/methods$",
 | ||||
| // 				},
 | ||||
| // 			},
 | ||||
| // 			shouldError: false,
 | ||||
| // 		},
 | ||||
| // 		{
 | ||||
| // 			name: "Invalid skipAuthRegex entry",
 | ||||
| // 			skipAuthRegex: []string{
 | ||||
| // 				"^/foo/bar",
 | ||||
| // 				"^/baz/[0-9]+/thing",
 | ||||
| // 				"(bad[regex",
 | ||||
| // 			},
 | ||||
| // 			skipAuthRoutes: []string{},
 | ||||
| // 			expectedRoutes: []expectedAllowedRoute{},
 | ||||
| // 			shouldError:    true,
 | ||||
| // 		},
 | ||||
| // 		{
 | ||||
| // 			name:          "Invalid skipAuthRoutes entry",
 | ||||
| // 			skipAuthRegex: []string{},
 | ||||
| // 			skipAuthRoutes: []string{
 | ||||
| // 				"GET=^/foo/bar",
 | ||||
| // 				"POST=^/baz/[0-9]+/thing",
 | ||||
| // 				"^/all/methods$",
 | ||||
| // 				"PUT=(bad[regex",
 | ||||
| // 			},
 | ||||
| // 			expectedRoutes: []expectedAllowedRoute{},
 | ||||
| // 			shouldError:    true,
 | ||||
| // 		},
 | ||||
| // 	}
 | ||||
| //
 | ||||
| // 	for _, tc := range testCases {
 | ||||
| // 		t.Run(tc.name, func(t *testing.T) {
 | ||||
| // 			opts := &options.Options{
 | ||||
| // 				SkipAuthRegex:  tc.skipAuthRegex,
 | ||||
| // 				SkipAuthRoutes: tc.skipAuthRoutes,
 | ||||
| // 			}
 | ||||
| // 			routes, err := buildRoutesAllowlist(opts)
 | ||||
| // 			if tc.shouldError {
 | ||||
| // 				assert.Error(t, err)
 | ||||
| // 				return
 | ||||
| // 			}
 | ||||
| // 			assert.NoError(t, err)
 | ||||
| //
 | ||||
| // 			for i, route := range routes {
 | ||||
| // 				assert.Greater(t, len(tc.expectedRoutes), i)
 | ||||
| // 				assert.Equal(t, route.method, tc.expectedRoutes[i].method)
 | ||||
| // 				assert.Equal(t, route.pathRegex.String(), tc.expectedRoutes[i].regexString)
 | ||||
| // 			}
 | ||||
| // 		})
 | ||||
| // 	}
 | ||||
| // }
 | ||||
| 
 | ||||
| func TestAllowedRequest(t *testing.T) { | ||||
| 	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
|  | @ -2356,12 +2378,20 @@ func TestAllowedRequest(t *testing.T) { | |||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	opts.SkipAuthRegex = []string{ | ||||
| 		"^/skip/auth/regex$", | ||||
| 	} | ||||
| 	opts.SkipAuthRoutes = []string{ | ||||
| 		"GET=^/skip/auth/routes/get", | ||||
| 	opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||
| 		{ | ||||
| 			ID:     "regex", | ||||
| 			Path:   "^/skip/auth/regex$", | ||||
| 			Policy: options.AllowPolicy, | ||||
| 		}, | ||||
| 		{ | ||||
| 			ID:      "route", | ||||
| 			Path:    "^/skip/auth/routes/get", | ||||
| 			Methods: []string{http.MethodGet}, | ||||
| 			Policy:  options.AllowPolicy, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	err := validation.Validate(opts) | ||||
| 	assert.NoError(t, err) | ||||
| 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) | ||||
|  | @ -2417,7 +2447,6 @@ func TestAllowedRequest(t *testing.T) { | |||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			req, err := http.NewRequest(tc.method, tc.url, nil) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, tc.allowed, proxy.isAllowedRoute(req)) | ||||
| 
 | ||||
| 			rw := httptest.NewRecorder() | ||||
| 			proxy.ServeHTTP(rw, req) | ||||
|  | @ -2670,8 +2699,18 @@ func TestAuthOnlyAllowedGroupsWithSkipMethods(t *testing.T) { | |||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			test, err := NewAuthOnlyEndpointTest("?allowed_groups=a,b", func(opts *options.Options) { | ||||
| 				opts.SkipAuthPreflight = true | ||||
| 				opts.TrustedIPs = []string{"1.2.3.4"} | ||||
| 				opts.Authorization.RequestRules = []options.AuthorizationRule{ | ||||
| 					{ | ||||
| 						ID:      "skip-auth-preflight", | ||||
| 						Methods: []string{http.MethodOptions}, | ||||
| 						Policy:  options.AllowPolicy, | ||||
| 					}, | ||||
| 					{ | ||||
| 						ID:     "trusted-ips", | ||||
| 						IPs:    []string{"1.2.3.4"}, | ||||
| 						Policy: options.AllowPolicy, | ||||
| 					}, | ||||
| 				} | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
|  |  | |||
|  | @ -4,32 +4,24 @@ import ( | |||
| 	"net" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| ) | ||||
| 
 | ||||
| type AuthorizationPolicy int | ||||
| 
 | ||||
| const ( | ||||
| 	NonePolicy AuthorizationPolicy = iota | ||||
| 	AllowPolicy | ||||
| 	DelegatePolicy | ||||
| 	DenyPolicy | ||||
| ) | ||||
| 
 | ||||
| type RuleSet interface { | ||||
| 	MatchesRequest(req *http.Request) AuthorizationPolicy | ||||
| 	MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy | ||||
| } | ||||
| 
 | ||||
| type rule struct { | ||||
| 	conditions []condition | ||||
| 	policy     AuthorizationPolicy | ||||
| 	policy     middlewareapi.AuthorizationPolicy | ||||
| } | ||||
| 
 | ||||
| func (r rule) matches(req *http.Request) AuthorizationPolicy { | ||||
| func (r rule) matches(req *http.Request) middlewareapi.AuthorizationPolicy { | ||||
| 	for _, condition := range r.conditions { | ||||
| 		if !condition.matches(req) { | ||||
| 			// One of the conditions didn't match so this rule does not apply
 | ||||
| 			return NonePolicy | ||||
| 			return middlewareapi.OmittedPolicy | ||||
| 		} | ||||
| 	} | ||||
| 	// If all conditions match, return the configured rule policy
 | ||||
|  | @ -60,17 +52,17 @@ func newRule(authRule options.AuthorizationRule, getClientIPFunc func(*http.Requ | |||
| 		conditions = append(conditions, condition) | ||||
| 	} | ||||
| 
 | ||||
| 	var policy AuthorizationPolicy | ||||
| 	var policy middlewareapi.AuthorizationPolicy | ||||
| 	switch authRule.Policy { | ||||
| 	case options.AllowPolicy: | ||||
| 		policy = AllowPolicy | ||||
| 		policy = middlewareapi.AllowPolicy | ||||
| 	case options.DelegatePolicy: | ||||
| 		policy = DelegatePolicy | ||||
| 		policy = middlewareapi.DelegatePolicy | ||||
| 	case options.DenyPolicy: | ||||
| 		policy = DenyPolicy | ||||
| 		policy = middlewareapi.DenyPolicy | ||||
| 	default: | ||||
| 		// This shouldn't be the case and should be prevented by validation
 | ||||
| 		policy = NonePolicy | ||||
| 		policy = middlewareapi.OmittedPolicy | ||||
| 	} | ||||
| 
 | ||||
| 	return rule{ | ||||
|  | @ -83,15 +75,15 @@ type ruleSet struct { | |||
| 	rules []rule | ||||
| } | ||||
| 
 | ||||
| func (r ruleSet) MatchesRequest(req *http.Request) AuthorizationPolicy { | ||||
| func (r ruleSet) MatchesRequest(req *http.Request) middlewareapi.AuthorizationPolicy { | ||||
| 	for _, rule := range r.rules { | ||||
| 		if policy := rule.matches(req); policy != NonePolicy { | ||||
| 		if policy := rule.matches(req); policy != middlewareapi.OmittedPolicy { | ||||
| 			// The rule applies to this request, return its policy
 | ||||
| 			return policy | ||||
| 		} | ||||
| 	} | ||||
| 	// No rules matched
 | ||||
| 	return NonePolicy | ||||
| 	return middlewareapi.OmittedPolicy | ||||
| } | ||||
| 
 | ||||
| func NewRuleSet(requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (RuleSet, error) { | ||||
|  |  | |||
|  | @ -6,10 +6,11 @@ import ( | |||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| ) | ||||
| 
 | ||||
| var result AuthorizationPolicy | ||||
| var result middlewareapi.AuthorizationPolicy | ||||
| 
 | ||||
| func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { | ||||
| 	rule1 := options.AuthorizationRule{ | ||||
|  | @ -53,10 +54,10 @@ func benchmarkRuleSetMatches(ruleCount int, b *testing.B) { | |||
| 
 | ||||
| 	req := httptest.NewRequest("GET", "/foo/bar/baz", nil) | ||||
| 
 | ||||
| 	var r AuthorizationPolicy | ||||
| 	var r middlewareapi.AuthorizationPolicy | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		r = ruleSet.MatchesRequest(req) | ||||
| 		if r != NonePolicy { | ||||
| 		if r != middlewareapi.OmittedPolicy { | ||||
| 			b.Fatal("expected policy not to match") | ||||
| 		} | ||||
| 	} | ||||
|  |  | |||
|  | @ -0,0 +1,61 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authorization" | ||||
| ) | ||||
| 
 | ||||
| func NewRequestAuthorization(writer pagewriter.Writer, requestRules []options.AuthorizationRule, getClientIPFunc func(*http.Request) net.IP) (alice.Constructor, error) { | ||||
| 	ruleset, err := authorization.NewRuleSet(requestRules, getClientIPFunc) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not initialise ruleset: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	ra := &requestAuthorizer{ | ||||
| 		ruleset: ruleset, | ||||
| 		writer:  writer, | ||||
| 	} | ||||
| 
 | ||||
| 	return ra.checkRequestAuthorization, nil | ||||
| } | ||||
| 
 | ||||
| type requestAuthorizer struct { | ||||
| 	ruleset authorization.RuleSet | ||||
| 	writer  pagewriter.Writer | ||||
| } | ||||
| 
 | ||||
| func (r *requestAuthorizer) checkRequestAuthorization(next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := middlewareapi.GetRequestScope(req) | ||||
| 		// If scope is nil, this will panic.
 | ||||
| 		// A scope should always be injected before this handler is called.
 | ||||
| 		if scope.Authorization.Policy != middlewareapi.OmittedPolicy { | ||||
| 			// The request was already authorized, pass to the next handler
 | ||||
| 			next.ServeHTTP(rw, req) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		policy := r.ruleset.MatchesRequest(req) | ||||
| 		switch policy { | ||||
| 		case middlewareapi.AllowPolicy, middlewareapi.DelegatePolicy: | ||||
| 			scope.Authorization.Type = middlewareapi.RequestAuthorization | ||||
| 			scope.Authorization.Policy = policy | ||||
| 		case middlewareapi.DenyPolicy: | ||||
| 			r.writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{ | ||||
| 				Status:    http.StatusForbidden, | ||||
| 				RequestID: scope.RequestID, | ||||
| 				AppError:  "Request denied by authorization policy", | ||||
| 				Messages:  []interface{}{"Request denied by authorization policy"}, | ||||
| 			}) | ||||
| 		} | ||||
| 
 | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
|  | @ -1,70 +0,0 @@ | |||
| package validation | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" | ||||
| ) | ||||
| 
 | ||||
| func validateAllowlists(o *options.Options) []string { | ||||
| 	msgs := []string{} | ||||
| 
 | ||||
| 	msgs = append(msgs, validateRoutes(o)...) | ||||
| 	msgs = append(msgs, validateRegexes(o)...) | ||||
| 	msgs = append(msgs, validateTrustedIPs(o)...) | ||||
| 
 | ||||
| 	if len(o.TrustedIPs) > 0 && o.ReverseProxy { | ||||
| 		_, err := fmt.Fprintln(os.Stderr, "WARNING: mixing --trusted-ip with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy") | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| // validateRoutes validates method=path routes passed with options.SkipAuthRoutes
 | ||||
| func validateRoutes(o *options.Options) []string { | ||||
| 	msgs := []string{} | ||||
| 	for _, route := range o.SkipAuthRoutes { | ||||
| 		var regex string | ||||
| 		parts := strings.SplitN(route, "=", 2) | ||||
| 		if len(parts) == 1 { | ||||
| 			regex = parts[0] | ||||
| 		} else { | ||||
| 			regex = parts[1] | ||||
| 		} | ||||
| 		_, err := regexp.Compile(regex) | ||||
| 		if err != nil { | ||||
| 			msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) | ||||
| 		} | ||||
| 	} | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| // validateRegex validates regex paths passed with options.SkipAuthRegex
 | ||||
| func validateRegexes(o *options.Options) []string { | ||||
| 	msgs := []string{} | ||||
| 	for _, regex := range o.SkipAuthRegex { | ||||
| 		_, err := regexp.Compile(regex) | ||||
| 		if err != nil { | ||||
| 			msgs = append(msgs, fmt.Sprintf("error compiling regex /%s/: %v", regex, err)) | ||||
| 		} | ||||
| 	} | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| // validateTrustedIPs validates IP/CIDRs for IP based allowlists
 | ||||
| func validateTrustedIPs(o *options.Options) []string { | ||||
| 	msgs := []string{} | ||||
| 	for i, ipStr := range o.TrustedIPs { | ||||
| 		if nil == ip.ParseIPNet(ipStr) { | ||||
| 			msgs = append(msgs, fmt.Sprintf("trusted_ips[%d] (%s) could not be recognized", i, ipStr)) | ||||
| 		} | ||||
| 	} | ||||
| 	return msgs | ||||
| } | ||||
|  | @ -1,125 +1,125 @@ | |||
| package validation | ||||
| 
 | ||||
| import ( | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Allowlist", func() { | ||||
| 	type validateRoutesTableInput struct { | ||||
| 		routes     []string | ||||
| 		errStrings []string | ||||
| 	} | ||||
| 
 | ||||
| 	type validateRegexesTableInput struct { | ||||
| 		regexes    []string | ||||
| 		errStrings []string | ||||
| 	} | ||||
| 
 | ||||
| 	type validateTrustedIPsTableInput struct { | ||||
| 		trustedIPs []string | ||||
| 		errStrings []string | ||||
| 	} | ||||
| 
 | ||||
| 	DescribeTable("validateRoutes", | ||||
| 		func(r *validateRoutesTableInput) { | ||||
| 			opts := &options.Options{ | ||||
| 				SkipAuthRoutes: r.routes, | ||||
| 			} | ||||
| 			Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings)) | ||||
| 		}, | ||||
| 		Entry("Valid regex routes", &validateRoutesTableInput{ | ||||
| 			routes: []string{ | ||||
| 				"/foo", | ||||
| 				"POST=/foo/bar", | ||||
| 				"PUT=^/foo/bar$", | ||||
| 				"DELETE=/crazy/(?:regex)?/[^/]+/stuff$", | ||||
| 			}, | ||||
| 			errStrings: []string{}, | ||||
| 		}), | ||||
| 		Entry("Bad regexes do not compile", &validateRoutesTableInput{ | ||||
| 			routes: []string{ | ||||
| 				"POST=/(foo", | ||||
| 				"OPTIONS=/foo/bar)", | ||||
| 				"GET=^]/foo/bar[$", | ||||
| 				"GET=^]/foo/bar[$", | ||||
| 			}, | ||||
| 			errStrings: []string{ | ||||
| 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", | ||||
| 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", | ||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | ||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | ||||
| 			}, | ||||
| 		}), | ||||
| 	) | ||||
| 
 | ||||
| 	DescribeTable("validateRegexes", | ||||
| 		func(r *validateRegexesTableInput) { | ||||
| 			opts := &options.Options{ | ||||
| 				SkipAuthRegex: r.regexes, | ||||
| 			} | ||||
| 			Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings)) | ||||
| 		}, | ||||
| 		Entry("Valid regex routes", &validateRegexesTableInput{ | ||||
| 			regexes: []string{ | ||||
| 				"/foo", | ||||
| 				"/foo/bar", | ||||
| 				"^/foo/bar$", | ||||
| 				"/crazy/(?:regex)?/[^/]+/stuff$", | ||||
| 			}, | ||||
| 			errStrings: []string{}, | ||||
| 		}), | ||||
| 		Entry("Bad regexes do not compile", &validateRegexesTableInput{ | ||||
| 			regexes: []string{ | ||||
| 				"/(foo", | ||||
| 				"/foo/bar)", | ||||
| 				"^]/foo/bar[$", | ||||
| 				"^]/foo/bar[$", | ||||
| 			}, | ||||
| 			errStrings: []string{ | ||||
| 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`", | ||||
| 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`", | ||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | ||||
| 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`", | ||||
| 			}, | ||||
| 		}), | ||||
| 	) | ||||
| 
 | ||||
| 	DescribeTable("validateTrustedIPs", | ||||
| 		func(t *validateTrustedIPsTableInput) { | ||||
| 			opts := &options.Options{ | ||||
| 				TrustedIPs: t.trustedIPs, | ||||
| 			} | ||||
| 			Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings)) | ||||
| 		}, | ||||
| 		Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{ | ||||
| 			trustedIPs: []string{ | ||||
| 				"127.0.0.1", | ||||
| 				"10.32.0.1/32", | ||||
| 				"43.36.201.0/24", | ||||
| 				"::1", | ||||
| 				"2a12:105:ee7:9234:0:0:0:0/64", | ||||
| 			}, | ||||
| 			errStrings: []string{}, | ||||
| 		}), | ||||
| 		Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{ | ||||
| 			trustedIPs: []string{ | ||||
| 				"135.180.78.199", | ||||
| 				"135.180.78.199/32", | ||||
| 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4", | ||||
| 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128", | ||||
| 			}, | ||||
| 			errStrings: []string{}, | ||||
| 		}), | ||||
| 		Entry("Invalid IPs", &validateTrustedIPsTableInput{ | ||||
| 			trustedIPs: []string{"[::1]", "alkwlkbn/32"}, | ||||
| 			errStrings: []string{ | ||||
| 				"trusted_ips[0] ([::1]) could not be recognized", | ||||
| 				"trusted_ips[1] (alkwlkbn/32) could not be recognized", | ||||
| 			}, | ||||
| 		}), | ||||
| 	) | ||||
| }) | ||||
| // import (
 | ||||
| // 	. "github.com/onsi/ginkgo"
 | ||||
| // 	. "github.com/onsi/ginkgo/extensions/table"
 | ||||
| // 	. "github.com/onsi/gomega"
 | ||||
| //
 | ||||
| // 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | ||||
| // )
 | ||||
| //
 | ||||
| // var _ = Describe("Allowlist", func() {
 | ||||
| // 	type validateRoutesTableInput struct {
 | ||||
| // 		routes     []string
 | ||||
| // 		errStrings []string
 | ||||
| // 	}
 | ||||
| //
 | ||||
| // 	type validateRegexesTableInput struct {
 | ||||
| // 		regexes    []string
 | ||||
| // 		errStrings []string
 | ||||
| // 	}
 | ||||
| //
 | ||||
| // 	type validateTrustedIPsTableInput struct {
 | ||||
| // 		trustedIPs []string
 | ||||
| // 		errStrings []string
 | ||||
| // 	}
 | ||||
| //
 | ||||
| // 	DescribeTable("validateRoutes",
 | ||||
| // 		func(r *validateRoutesTableInput) {
 | ||||
| // 			opts := &options.Options{
 | ||||
| // 				SkipAuthRoutes: r.routes,
 | ||||
| // 			}
 | ||||
| // 			Expect(validateRoutes(opts)).To(ConsistOf(r.errStrings))
 | ||||
| // 		},
 | ||||
| // 		Entry("Valid regex routes", &validateRoutesTableInput{
 | ||||
| // 			routes: []string{
 | ||||
| // 				"/foo",
 | ||||
| // 				"POST=/foo/bar",
 | ||||
| // 				"PUT=^/foo/bar$",
 | ||||
| // 				"DELETE=/crazy/(?:regex)?/[^/]+/stuff$",
 | ||||
| // 			},
 | ||||
| // 			errStrings: []string{},
 | ||||
| // 		}),
 | ||||
| // 		Entry("Bad regexes do not compile", &validateRoutesTableInput{
 | ||||
| // 			routes: []string{
 | ||||
| // 				"POST=/(foo",
 | ||||
| // 				"OPTIONS=/foo/bar)",
 | ||||
| // 				"GET=^]/foo/bar[$",
 | ||||
| // 				"GET=^]/foo/bar[$",
 | ||||
| // 			},
 | ||||
| // 			errStrings: []string{
 | ||||
| // 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`",
 | ||||
| // 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`",
 | ||||
| // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||
| // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||
| // 			},
 | ||||
| // 		}),
 | ||||
| // 	)
 | ||||
| //
 | ||||
| // 	DescribeTable("validateRegexes",
 | ||||
| // 		func(r *validateRegexesTableInput) {
 | ||||
| // 			opts := &options.Options{
 | ||||
| // 				SkipAuthRegex: r.regexes,
 | ||||
| // 			}
 | ||||
| // 			Expect(validateRegexes(opts)).To(ConsistOf(r.errStrings))
 | ||||
| // 		},
 | ||||
| // 		Entry("Valid regex routes", &validateRegexesTableInput{
 | ||||
| // 			regexes: []string{
 | ||||
| // 				"/foo",
 | ||||
| // 				"/foo/bar",
 | ||||
| // 				"^/foo/bar$",
 | ||||
| // 				"/crazy/(?:regex)?/[^/]+/stuff$",
 | ||||
| // 			},
 | ||||
| // 			errStrings: []string{},
 | ||||
| // 		}),
 | ||||
| // 		Entry("Bad regexes do not compile", &validateRegexesTableInput{
 | ||||
| // 			regexes: []string{
 | ||||
| // 				"/(foo",
 | ||||
| // 				"/foo/bar)",
 | ||||
| // 				"^]/foo/bar[$",
 | ||||
| // 				"^]/foo/bar[$",
 | ||||
| // 			},
 | ||||
| // 			errStrings: []string{
 | ||||
| // 				"error compiling regex //(foo/: error parsing regexp: missing closing ): `/(foo`",
 | ||||
| // 				"error compiling regex //foo/bar)/: error parsing regexp: unexpected ): `/foo/bar)`",
 | ||||
| // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||
| // 				"error compiling regex /^]/foo/bar[$/: error parsing regexp: missing closing ]: `[$`",
 | ||||
| // 			},
 | ||||
| // 		}),
 | ||||
| // 	)
 | ||||
| //
 | ||||
| // 	DescribeTable("validateTrustedIPs",
 | ||||
| // 		func(t *validateTrustedIPsTableInput) {
 | ||||
| // 			opts := &options.Options{
 | ||||
| // 				TrustedIPs: t.trustedIPs,
 | ||||
| // 			}
 | ||||
| // 			Expect(validateTrustedIPs(opts)).To(ConsistOf(t.errStrings))
 | ||||
| // 		},
 | ||||
| // 		Entry("Non-overlapping valid IPs", &validateTrustedIPsTableInput{
 | ||||
| // 			trustedIPs: []string{
 | ||||
| // 				"127.0.0.1",
 | ||||
| // 				"10.32.0.1/32",
 | ||||
| // 				"43.36.201.0/24",
 | ||||
| // 				"::1",
 | ||||
| // 				"2a12:105:ee7:9234:0:0:0:0/64",
 | ||||
| // 			},
 | ||||
| // 			errStrings: []string{},
 | ||||
| // 		}),
 | ||||
| // 		Entry("Overlapping valid IPs", &validateTrustedIPsTableInput{
 | ||||
| // 			trustedIPs: []string{
 | ||||
| // 				"135.180.78.199",
 | ||||
| // 				"135.180.78.199/32",
 | ||||
| // 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4",
 | ||||
| // 				"d910:a5a1:16f8:ddf5:e5b9:5cef:a65e:41f4/128",
 | ||||
| // 			},
 | ||||
| // 			errStrings: []string{},
 | ||||
| // 		}),
 | ||||
| // 		Entry("Invalid IPs", &validateTrustedIPsTableInput{
 | ||||
| // 			trustedIPs: []string{"[::1]", "alkwlkbn/32"},
 | ||||
| // 			errStrings: []string{
 | ||||
| // 				"trusted_ips[0] ([::1]) could not be recognized",
 | ||||
| // 				"trusted_ips[1] (alkwlkbn/32) could not be recognized",
 | ||||
| // 			},
 | ||||
| // 		}),
 | ||||
| // 	)
 | ||||
| // })
 | ||||
|  |  | |||
|  | @ -0,0 +1,94 @@ | |||
| package validation | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"regexp" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" | ||||
| ) | ||||
| 
 | ||||
| func validateAuthorization(authorization options.Authorization, reverseProxy bool) []string { | ||||
| 	msgs := []string{} | ||||
| 
 | ||||
| 	msgs = append(msgs, validateRequestRules(authorization.RequestRules, reverseProxy)...) | ||||
| 
 | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| func validateRequestRules(rules []options.AuthorizationRule, reverseProxy bool) []string { | ||||
| 	msgs := []string{} | ||||
| 
 | ||||
| 	ids := make(map[string]struct{}) | ||||
| 
 | ||||
| 	for _, rule := range rules { | ||||
| 		msgs = append(msgs, validateRequestRule(ids, rule, reverseProxy)...) | ||||
| 	} | ||||
| 
 | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| func validateRequestRule(ids map[string]struct{}, rule options.AuthorizationRule, reverseProxy bool) []string { | ||||
| 	msgs := []string{} | ||||
| 
 | ||||
| 	if rule.ID == "" { | ||||
| 		msgs = append(msgs, "request rule has empty ID: IDs are required for all request rules") | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := ids[rule.ID]; ok { | ||||
| 		msgs = append(msgs, fmt.Sprintf("multiple request rules found with ID %q: request rule IDs must be unique", rule.ID)) | ||||
| 	} | ||||
| 	ids[rule.ID] = struct{}{} | ||||
| 
 | ||||
| 	msgs = append(msgs, validateRequestRulePolicy(rule.ID, rule.Policy)...) | ||||
| 	msgs = append(msgs, validateRequestRulePath(rule.ID, rule.Path)...) | ||||
| 	msgs = append(msgs, validateRequestRuleIPs(rule.ID, rule.IPs, reverseProxy)...) | ||||
| 
 | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| func validateRequestRulePolicy(ruleID string, policy options.AuthorizationPolicy) []string { | ||||
| 	msgs := []string{} | ||||
| 
 | ||||
| 	switch policy { | ||||
| 	case options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy: | ||||
| 		// Do nothing for valid options
 | ||||
| 	default: | ||||
| 		msgs = append(msgs, fmt.Sprintf("request rule %q has invalid policy (%s): policy must be one of %s, %s or %s", ruleID, policy, options.AllowPolicy, options.DenyPolicy, options.DelegatePolicy)) | ||||
| 	} | ||||
| 
 | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| // validateRequestRulePath validates paths for path/regex based conditions
 | ||||
| func validateRequestRulePath(ruleID string, path string) []string { | ||||
| 	msgs := []string{} | ||||
| 
 | ||||
| 	_, err := regexp.Compile(path) | ||||
| 	if err != nil { | ||||
| 		msgs = append(msgs, fmt.Sprintf("error compiling path regex (%s) for rule %q: %v", path, ruleID, err)) | ||||
| 	} | ||||
| 
 | ||||
| 	return msgs | ||||
| } | ||||
| 
 | ||||
| // validateRequestRuleIPs validates IP/CIDRs for IP based conditions.
 | ||||
| func validateRequestRuleIPs(ruleID string, ips []string, reverseProxy bool) []string { | ||||
| 	msgs := []string{} | ||||
| 
 | ||||
| 	if len(ips) > 0 && reverseProxy { | ||||
| 		_, err := fmt.Fprintln(os.Stderr, "WARNING: mixing IP authorization with --reverse-proxy is a potential security vulnerability. An attacker can inject a trusted IP into an X-Real-IP or X-Forwarded-For header if they aren't properly protected outside of oauth2-proxy") | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for i, ipStr := range ips { | ||||
| 		if nil == ip.ParseIPNet(ipStr) { | ||||
| 			msgs = append(msgs, fmt.Sprintf("rule %q IP [%d] (%s) could not be recognized", ruleID, i, ipStr)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return msgs | ||||
| } | ||||
|  | @ -20,6 +20,7 @@ import ( | |||
| // are of the correct format
 | ||||
| func Validate(o *options.Options) error { | ||||
| 	msgs := validateCookie(o.Cookie) | ||||
| 	msgs = append(msgs, validateAuthorization(o.Authorization, o.ReverseProxy)...) | ||||
| 	msgs = append(msgs, validateSessionCookieMinimal(o)...) | ||||
| 	msgs = append(msgs, validateRedisSessionStore(o)...) | ||||
| 	msgs = append(msgs, prefixValues("injectRequestHeaders: ", validateHeaders(o.InjectRequestHeaders)...)...) | ||||
|  | @ -96,9 +97,6 @@ func Validate(o *options.Options) error { | |||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	// Do this after ReverseProxy validation for TrustedIP coordinated checks
 | ||||
| 	msgs = append(msgs, validateAllowlists(o)...) | ||||
| 
 | ||||
| 	if len(msgs) != 0 { | ||||
| 		return fmt.Errorf("invalid configuration:\n  %s", | ||||
| 			strings.Join(msgs, "\n  ")) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue