diff --git a/main.go b/main.go index 42e8bab0..1c6a6b21 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,8 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/version" "github.com/spf13/pflag" @@ -59,6 +61,21 @@ func main() { logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) } + // Set up dynamic CA certificate reloading if CA files are configured + if !opts.SSLInsecureSkipVerify && len(opts.Providers[0].CAFiles) > 0 { + caLoader, err := requests.NewDynamicCALoader( + opts.Providers[0].CAFiles, + ptr.Deref(opts.Providers[0].UseSystemTrustStore, options.DefaultUseSystemTrustStore), + ) + if err != nil { + logger.Fatalf("ERROR: Failed to load CA certificates: %v", err) + } + requests.SetDefaultTransport(requests.NewDynamicTLSTransport(caLoader)) + if err := caLoader.StartWatching(nil); err != nil { + logger.Printf("WARNING: Failed to start CA file watching: %v", err) + } + } + if err := oauthproxy.Start(); err != nil { logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) } diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 13ce2e0b..45bc4e83 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/url" + "os" "strings" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" @@ -14,8 +15,6 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" ) // Validate checks that required options are set and validates those that they @@ -32,18 +31,15 @@ func Validate(o *options.Options) error { msgs = parseSignatureKey(o, msgs) if o.SSLInsecureSkipVerify { - transport := requests.DefaultTransport.(*http.Transport) + transport := http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 -- InsecureSkipVerify is a configurable option we allow + requests.SetDefaultTransport(transport) } else if len(o.Providers[0].CAFiles) > 0 { - pool, err := util.GetCertPool(o.Providers[0].CAFiles, ptr.Deref(o.Providers[0].UseSystemTrustStore, options.DefaultUseSystemTrustStore)) - if err == nil { - transport := requests.DefaultTransport.(*http.Transport) - transport.TLSClientConfig = &tls.Config{ - RootCAs: pool, - MinVersion: tls.VersionTLS12, + // Validate CA files are readable (actual loading happens in main.go) + for _, caFile := range o.Providers[0].CAFiles { + if _, err := os.Stat(caFile); err != nil { + msgs = append(msgs, fmt.Sprintf("unable to load provider CA file(s): %v", err)) } - } else { - msgs = append(msgs, fmt.Sprintf("unable to load provider CA file(s): %v", err)) } }