From ceb9a387b12b06230e4e03fef67edbb4931bc1a1 Mon Sep 17 00:00:00 2001 From: Jan Larwig Date: Fri, 31 Oct 2025 16:11:54 +0100 Subject: [PATCH] deref everything... but why? Signed-off-by: Jan Larwig --- main.go | 7 +++++-- main_test.go | 17 ++++++++++++----- pkg/middleware/headers.go | 3 ++- pkg/upstream/http.go | 8 ++++---- pkg/upstream/proxy.go | 7 ++++--- pkg/upstream/static.go | 13 ++----------- pkg/validation/options.go | 3 ++- pkg/validation/providers.go | 11 ++++++----- pkg/validation/upstreams.go | 17 +++++++++-------- providers/adfs.go | 3 ++- providers/google.go | 9 ++++----- providers/ms_entra_id.go | 3 ++- providers/oidc.go | 3 ++- providers/providers.go | 9 +++++---- 14 files changed, 61 insertions(+), 52 deletions(-) diff --git a/main.go b/main.go index 7e18b95f..42e8bab0 100644 --- a/main.go +++ b/main.go @@ -75,11 +75,14 @@ func loadConfiguration(config, yamlConfig string, extraFlags *pflag.FlagSet, arg if yamlConfig != "" { logger.Printf("WARNING: You are using alpha configuration. The structure in this configuration file may change without notice. You MUST remove conflicting options from your existing configuration.") - return loadYamlOptions(yamlConfig, config, extraFlags, args) + opts, err = loadYamlOptions(yamlConfig, config, extraFlags, args) + if err != nil { + return nil, fmt.Errorf("failed to load yaml options: %w", err) + } } + // Ensure defaults after loading configuration opts.EnsureDefaults() - return opts, nil } diff --git a/main_test.go b/main_test.go index 0323c838..1d9102d7 100644 --- a/main_test.go +++ b/main_test.go @@ -123,6 +123,7 @@ redirect_url="http://localhost:4180/oauth2/callback" opts.RawRedirectURL = "http://localhost:4180/oauth2/callback" opts.UpstreamServers = options.UpstreamConfig{ + ProxyRawPath: ptr.Ptr(false), Upstreams: []options.Upstream{ { ID: "/", @@ -132,6 +133,7 @@ redirect_url="http://localhost:4180/oauth2/callback" PassHostHeader: ptr.Ptr(true), ProxyWebSockets: ptr.Ptr(true), Timeout: ptr.Ptr(options.DefaultUpstreamTimeout), + Static: ptr.Ptr(false), InsecureSkipTLSVerify: ptr.Ptr(false), DisableKeepAlives: ptr.Ptr(false), }, @@ -139,7 +141,8 @@ redirect_url="http://localhost:4180/oauth2/callback" } authHeader := options.Header{ - Name: "Authorization", + Name: "Authorization", + PreserveRequestValue: ptr.Ptr(false), Values: []options.HeaderValue{ { ClaimSource: &options.ClaimSource{ @@ -153,10 +156,7 @@ redirect_url="http://localhost:4180/oauth2/callback" }, } - authHeader.PreserveRequestValue = ptr.Ptr(false) opts.InjectRequestHeaders = append([]options.Header{authHeader}, opts.InjectRequestHeaders...) - - authHeader.PreserveRequestValue = nil opts.InjectResponseHeaders = append(opts.InjectResponseHeaders, authHeader) opts.Providers = options.Providers{ @@ -186,6 +186,12 @@ redirect_url="http://localhost:4180/oauth2/callback" InsecureSkipIssuerVerification: ptr.Ptr(false), SkipDiscovery: ptr.Ptr(false), }, + MicrosoftEntraIDConfig: options.MicrosoftEntraIDOptions{ + FederatedTokenAuth: ptr.Ptr(false), + }, + ADFSConfig: options.ADFSOptions{ + SkipScope: ptr.Ptr(false), + }, LoginURLParameters: []options.LoginURLParameter{ {Name: "approval_prompt", Default: []string{"force"}}, }, @@ -254,7 +260,8 @@ redirect_url="http://localhost:4180/oauth2/callback" Expect(err).ToNot(HaveOccurred()) } Expect(in.expectedOptions).ToNot(BeNil()) - Expect(opts).To(EqualOpts(in.expectedOptions())) + expectedOpts := in.expectedOptions() + Expect(opts).To(EqualOpts(expectedOpts)) }, Entry("with legacy configuration", loadConfigurationTableInput{ configContent: testCoreConfig + testLegacyConfig, diff --git a/pkg/middleware/headers.go b/pkg/middleware/headers.go index d9287505..ed82a30b 100644 --- a/pkg/middleware/headers.go +++ b/pkg/middleware/headers.go @@ -9,6 +9,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) func NewRequestHeaderInjector(headers []options.Header) (alice.Constructor, error) { @@ -27,7 +28,7 @@ func NewRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro func newStripHeaders(headers []options.Header) alice.Constructor { headersToStrip := []options.Header{} for _, header := range headers { - if !(*header.PreserveRequestValue) { + if !ptr.Deref(header.PreserveRequestValue, false) { headersToStrip = append(headersToStrip, header) } } diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index 59580ee3..9c33f96c 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -54,7 +54,7 @@ func newHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *option // Set up a WebSocket proxy if required var wsProxy http.Handler - if *upstream.ProxyWebSockets { + if ptr.Deref(upstream.ProxyWebSockets, false) { wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify) } @@ -150,14 +150,14 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr // InsecureSkipVerify is a configurable option we allow /* #nosec G402 */ - if *upstream.InsecureSkipTLSVerify { + if ptr.Deref(upstream.InsecureSkipTLSVerify, false) { transport.TLSClientConfig.InsecureSkipVerify = true } // Ensure we always pass the original request path setProxyDirector(proxy) - if upstream.PassHostHeader != nil && !(*upstream.PassHostHeader) { + if !ptr.Deref(upstream.PassHostHeader, false) { setProxyUpstreamHostHeader(proxy, target) } @@ -169,7 +169,7 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr // Pass on DisableKeepAlives to the transport settings // to allow for disabling HTTP keep-alive connections - transport.DisableKeepAlives = *upstream.DisableKeepAlives + transport.DisableKeepAlives = ptr.Deref(upstream.DisableKeepAlives, false) // Apply the customized transport to our proxy before returning it proxy.Transport = transport diff --git a/pkg/upstream/proxy.go b/pkg/upstream/proxy.go index 0d2286ea..acf24d1a 100644 --- a/pkg/upstream/proxy.go +++ b/pkg/upstream/proxy.go @@ -14,6 +14,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) // ProxyErrorHandler is a function that will be used to render error pages when @@ -27,12 +28,12 @@ func NewProxy(upstreams options.UpstreamConfig, sigData *options.SignatureData, serveMux: mux.NewRouter(), } - if *upstreams.ProxyRawPath { + if ptr.Deref(upstreams.ProxyRawPath, false) { m.serveMux.UseEncodedPath() } for _, upstream := range sortByPathLongest(upstreams.Upstreams) { - if *upstream.Static { + if ptr.Deref(upstream.Static, false) { if err := m.registerStaticResponseHandler(upstream, writer); err != nil { return nil, fmt.Errorf("could not register static upstream %q: %v", upstream.ID, err) } @@ -74,7 +75,7 @@ func (m *multiUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request // registerStaticResponseHandler registers a static response handler with at the given path. func (m *multiUpstreamProxy) registerStaticResponseHandler(upstream options.Upstream, writer pagewriter.Writer) error { - logger.Printf("mapping path %q => static response %d", upstream.Path, derefStaticCode(upstream.StaticCode)) + logger.Printf("mapping path %q => static response %d", upstream.Path, ptr.Deref(upstream.StaticCode, 200)) return m.registerHandler(upstream, newStaticResponseHandler(upstream.ID, upstream.StaticCode), writer) } diff --git a/pkg/upstream/static.go b/pkg/upstream/static.go index 027f3e74..d7d037bf 100644 --- a/pkg/upstream/static.go +++ b/pkg/upstream/static.go @@ -6,15 +6,14 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) -const defaultStaticResponseCode = 200 - // newStaticResponseHandler creates a new staticResponseHandler that serves a // a static response code. func newStaticResponseHandler(upstream string, code *int) http.Handler { return &staticResponseHandler{ - code: derefStaticCode(code), + code: ptr.Deref(code, 200), upstream: upstream, } } @@ -38,11 +37,3 @@ func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Requ logger.Errorf("Error writing static response: %v", err) } } - -// derefStaticCode returns the derefenced value, or the default if the value is nil -func derefStaticCode(code *int) int { - if code != nil { - return *code - } - return defaultStaticResponseCode -} diff --git a/pkg/validation/options.go b/pkg/validation/options.go index d5aba4e4..ffb0accc 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -15,6 +15,7 @@ import ( internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) // Validate checks that required options are set and validates those that they @@ -34,7 +35,7 @@ func Validate(o *options.Options) error { transport := requests.DefaultTransport.(*http.Transport) transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 -- InsecureSkipVerify is a configurable option we allow } else if len(o.Providers[0].CAFiles) > 0 { - pool, err := util.GetCertPool(o.Providers[0].CAFiles, *o.Providers[0].UseSystemTrustStore) + pool, err := util.GetCertPool(o.Providers[0].CAFiles, ptr.Deref(o.Providers[0].UseSystemTrustStore, false)) if err == nil { transport := requests.DefaultTransport.(*http.Transport) transport.TLSClientConfig = &tls.Config{ diff --git a/pkg/validation/providers.go b/pkg/validation/providers.go index 345274d8..1acdee65 100644 --- a/pkg/validation/providers.go +++ b/pkg/validation/providers.go @@ -5,6 +5,7 @@ import ( "os" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) // validateProviders is the initial validation migration for multiple providrers @@ -64,7 +65,7 @@ func validateProvider(provider options.Provider, providerIDs map[string]struct{} // providerRequiresClientSecret checks if provider requires client secret to be set // or it can be omitted in favor of JWT token to authenticate oAuth client func providerRequiresClientSecret(provider options.Provider) bool { - if provider.Type == "entra-id" && *provider.MicrosoftEntraIDConfig.FederatedTokenAuth { + if provider.Type == "entra-id" && ptr.Deref(provider.MicrosoftEntraIDConfig.FederatedTokenAuth, false) { return false } @@ -96,9 +97,9 @@ func validateGoogleConfig(provider options.Provider) []string { hasAdminEmail := provider.GoogleConfig.AdminEmail != "" hasSAJSON := provider.GoogleConfig.ServiceAccountJSON != "" - useADC := provider.GoogleConfig.UseApplicationDefaultCredentials + useADC := ptr.Deref(provider.GoogleConfig.UseApplicationDefaultCredentials, false) - if !hasAdminEmail && !hasSAJSON && !(*useADC) { + if !hasAdminEmail && !hasSAJSON && !useADC { return msgs } @@ -107,7 +108,7 @@ func validateGoogleConfig(provider options.Provider) []string { } _, err := os.Stat(provider.GoogleConfig.ServiceAccountJSON) - if !(*useADC) { + if !useADC { if !hasSAJSON { msgs = append(msgs, "missing setting: google-service-account-json or google-use-application-default-credentials") } else if err != nil { @@ -123,7 +124,7 @@ func validateGoogleConfig(provider options.Provider) []string { func validateEntraConfig(provider options.Provider) []string { msgs := []string{} - if *provider.MicrosoftEntraIDConfig.FederatedTokenAuth { + if ptr.Deref(provider.MicrosoftEntraIDConfig.FederatedTokenAuth, false) { federatedTokenPath := os.Getenv("AZURE_FEDERATED_TOKEN_FILE") if federatedTokenPath == "" { diff --git a/pkg/validation/upstreams.go b/pkg/validation/upstreams.go index 2750e22c..4c537678 100644 --- a/pkg/validation/upstreams.go +++ b/pkg/validation/upstreams.go @@ -5,6 +5,7 @@ import ( "net/url" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) func validateUpstreams(upstreams options.UpstreamConfig) []string { @@ -54,28 +55,28 @@ func validateUpstream(upstream options.Upstream, ids, paths map[string]struct{}) func validateStaticUpstream(upstream options.Upstream) []string { msgs := []string{} - if !(*upstream.Static) && upstream.StaticCode != nil { + if !ptr.Deref(upstream.Static, false) && upstream.StaticCode != nil { msgs = append(msgs, fmt.Sprintf("upstream %q has staticCode (%d), but is not a static upstream, set 'static' for a static response", upstream.ID, *upstream.StaticCode)) } // Checks after this only make sense when the upstream is static - if !(*upstream.Static) { + if !ptr.Deref(upstream.Static, false) { return msgs } if upstream.URI != "" { msgs = append(msgs, fmt.Sprintf("upstream %q has uri, but is a static upstream, this will have no effect.", upstream.ID)) } - if *upstream.InsecureSkipTLSVerify { + if ptr.Deref(upstream.InsecureSkipTLSVerify, false) { msgs = append(msgs, fmt.Sprintf("upstream %q has insecureSkipTLSVerify, but is a static upstream, this will have no effect.", upstream.ID)) } - if upstream.FlushInterval != nil && *upstream.FlushInterval != options.DefaultUpstreamFlushInterval { + if ptr.Deref(upstream.FlushInterval, options.DefaultUpstreamFlushInterval) != options.DefaultUpstreamFlushInterval { msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID)) } - if *upstream.PassHostHeader { + if ptr.Deref(upstream.PassHostHeader, false) { msgs = append(msgs, fmt.Sprintf("upstream %q has passHostHeader, but is a static upstream, this will have no effect.", upstream.ID)) } - if *upstream.ProxyWebSockets { + if ptr.Deref(upstream.ProxyWebSockets, false) { msgs = append(msgs, fmt.Sprintf("upstream %q has proxyWebSockets, but is a static upstream, this will have no effect.", upstream.ID)) } @@ -85,13 +86,13 @@ func validateStaticUpstream(upstream options.Upstream) []string { func validateUpstreamURI(upstream options.Upstream) []string { msgs := []string{} - if !(*upstream.Static) && upstream.URI == "" { + if !ptr.Deref(upstream.Static, false) && upstream.URI == "" { msgs = append(msgs, fmt.Sprintf("upstream %q has empty uri: uris are required for all non-static upstreams", upstream.ID)) return msgs } // Checks after this only make sense the upstream is not static - if *upstream.Static { + if !ptr.Deref(upstream.Static, false) { return msgs } diff --git a/providers/adfs.go b/providers/adfs.go index 6615f38c..ebf91f71 100644 --- a/providers/adfs.go +++ b/providers/adfs.go @@ -8,6 +8,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) // ADFSProvider represents an ADFS based Identity Provider @@ -50,7 +51,7 @@ func NewADFSProvider(p *ProviderData, opts options.Provider) *ADFSProvider { return &ADFSProvider{ OIDCProvider: oidcProvider, - skipScope: *opts.ADFSConfig.SkipScope, + skipScope: ptr.Deref(opts.ADFSConfig.SkipScope, false), oidcEnrichFunc: oidcProvider.EnrichSession, oidcRefreshFunc: oidcProvider.RefreshSession, } diff --git a/providers/google.go b/providers/google.go index 64102ebc..ac38980f 100644 --- a/providers/google.go +++ b/providers/google.go @@ -19,6 +19,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" "golang.org/x/oauth2" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" @@ -108,9 +109,7 @@ func NewGoogleProvider(p *ProviderData, opts options.GoogleOptions) (*GoogleProv }, } - if opts.UseOrganizationID || opts.ServiceAccountJSON != "" || *opts.UseApplicationDefaultCredentials { - provider.configureGroups(opts) - + if opts.UseOrganizationID || opts.ServiceAccountJSON != "" || ptr.Deref(opts.UseApplicationDefaultCredentials, false) { // reuse admin service to avoid multiple calls for token var adminService *admin.Service @@ -133,7 +132,7 @@ func NewGoogleProvider(p *ProviderData, opts options.GoogleOptions) (*GoogleProv } } - if opts.ServiceAccountJSON != "" || opts.UseApplicationDefaultCredentials { + if opts.ServiceAccountJSON != "" || ptr.Deref(opts.UseApplicationDefaultCredentials, false) { if adminService == nil { adminService = getAdminService(opts) } @@ -305,7 +304,7 @@ var possibleScopesList = [...]string{ } func getOauth2TokenSource(ctx context.Context, opts options.GoogleOptions, scope string) oauth2.TokenSource { - if *opts.UseApplicationDefaultCredentials { + if ptr.Deref(opts.UseApplicationDefaultCredentials, false) { ts, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{ TargetPrincipal: getTargetPrincipal(ctx, opts), Scopes: strings.Split(scope, " "), diff --git a/providers/ms_entra_id.go b/providers/ms_entra_id.go index 752f9f44..57c4fae1 100644 --- a/providers/ms_entra_id.go +++ b/providers/ms_entra_id.go @@ -16,6 +16,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" "github.com/spf13/cast" "golang.org/x/oauth2" ) @@ -51,7 +52,7 @@ func NewMicrosoftEntraIDProvider(p *ProviderData, opts options.Provider) *Micros OIDCProvider: NewOIDCProvider(p, opts.OIDCConfig), multiTenantAllowedTenants: opts.MicrosoftEntraIDConfig.AllowedTenants, - federatedTokenAuth: *opts.MicrosoftEntraIDConfig.FederatedTokenAuth, + federatedTokenAuth: ptr.Deref(opts.MicrosoftEntraIDConfig.FederatedTokenAuth, false), microsoftGraphURL: microsoftGraphURL, } } diff --git a/providers/oidc.go b/providers/oidc.go index fa65e839..5e28039d 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -12,6 +12,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" "golang.org/x/oauth2" ) @@ -50,7 +51,7 @@ func NewOIDCProvider(p *ProviderData, opts options.OIDCOptions) *OIDCProvider { return &OIDCProvider{ ProviderData: p, - SkipNonce: *opts.InsecureSkipNonce, + SkipNonce: ptr.Deref(opts.InsecureSkipNonce, false), } } diff --git a/providers/providers.go b/providers/providers.go index 84f5ec91..1c7ac652 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -9,6 +9,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" k8serrors "k8s.io/apimachinery/pkg/util/errors" ) @@ -98,8 +99,8 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, IssuerURL: providerConfig.OIDCConfig.IssuerURL, JWKsURL: providerConfig.OIDCConfig.JwksURL, PublicKeyFiles: providerConfig.OIDCConfig.PublicKeyFiles, - SkipDiscovery: *providerConfig.OIDCConfig.SkipDiscovery, - SkipIssuerVerification: *providerConfig.OIDCConfig.InsecureSkipIssuerVerification, + SkipDiscovery: ptr.Deref(providerConfig.OIDCConfig.SkipDiscovery, false), + SkipIssuerVerification: ptr.Deref(providerConfig.OIDCConfig.InsecureSkipIssuerVerification, false), }) if err != nil { return nil, fmt.Errorf("error building OIDC ProviderVerifier: %v", err) @@ -143,10 +144,10 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, } // Make the OIDC options available to all providers that support it - p.AllowUnverifiedEmail = *providerConfig.OIDCConfig.InsecureAllowUnverifiedEmail + p.AllowUnverifiedEmail = ptr.Deref(providerConfig.OIDCConfig.InsecureAllowUnverifiedEmail, false) p.EmailClaim = providerConfig.OIDCConfig.EmailClaim p.GroupsClaim = providerConfig.OIDCConfig.GroupsClaim - p.SkipClaimsFromProfileURL = *providerConfig.SkipClaimsFromProfileURL + p.SkipClaimsFromProfileURL = ptr.Deref(providerConfig.SkipClaimsFromProfileURL, false) // Set PKCE enabled or disabled based on discovery and force options p.CodeChallengeMethod = parseCodeChallengeMethod(providerConfig)