diff --git a/pkg/requests/http.go b/pkg/requests/http.go index ed335b86..222b92d3 100644 --- a/pkg/requests/http.go +++ b/pkg/requests/http.go @@ -7,20 +7,22 @@ import ( ) type userAgentTransport struct { - next http.RoundTripper + Next http.RoundTripper userAgent string } func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { r := req.Clone(req.Context()) setDefaultUserAgent(r.Header, t.userAgent) - return t.next.RoundTrip(r) + return t.Next.RoundTrip(r) } -var DefaultHTTPClient = &http.Client{Transport: &userAgentTransport{ - next: http.DefaultTransport, +var DefaultHTTPClient = &http.Client{Transport: &DefaultTransport} + +var DefaultTransport = userAgentTransport{ + Next: http.DefaultTransport, userAgent: "oauth2-proxy/" + version.VERSION, -}} +} func setDefaultUserAgent(header http.Header, userAgent string) { if header != nil && len(header.Values("User-Agent")) == 0 { diff --git a/pkg/validation/options.go b/pkg/validation/options.go index b14439a7..caf896c5 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -13,6 +13,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" ) @@ -30,20 +31,20 @@ func Validate(o *options.Options) error { msgs = parseSignatureKey(o, msgs) if o.SSLInsecureSkipVerify { - insecureTransport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // #nosec G402 -- InsecureSkipVerify is a configurable option we allow - } - http.DefaultClient = &http.Client{Transport: insecureTransport} + transport := requests.DefaultTransport.Next.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 -- InsecureSkipVerify is a configurable option we allow + + requests.DefaultHTTPClient = &http.Client{Transport: transport} } else if len(o.Providers[0].CAFiles) > 0 { pool, err := util.GetCertPool(o.Providers[0].CAFiles, o.Providers[0].UseSystemTrustStore) if err == nil { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := requests.DefaultTransport.Next.(*http.Transport).Clone() transport.TLSClientConfig = &tls.Config{ RootCAs: pool, MinVersion: tls.VersionTLS12, } - http.DefaultClient = &http.Client{Transport: transport} + requests.DefaultHTTPClient = &http.Client{Transport: transport} } else { msgs = append(msgs, fmt.Sprintf("unable to load provider CA file(s): %v", err)) }