diff --git a/CHANGELOG.md b/CHANGELOG.md index fc4c2379..5e2ffdb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ ## Changes since v7.15.1 +- [#3396](https://github.com/oauth2-proxy/oauth2-proxy/pull/3396) feat: add support for specifying custom OIDC JWT Issuer Headers (@xnox) + # V7.15.1 ## Release Highlights diff --git a/contrib/oauth2-proxy_autocomplete.sh b/contrib/oauth2-proxy_autocomplete.sh index 0dd8d304..d1250aad 100644 --- a/contrib/oauth2-proxy_autocomplete.sh +++ b/contrib/oauth2-proxy_autocomplete.sh @@ -24,7 +24,7 @@ _oauth2_proxy() { COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) ) return 0 ;; - --@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|ready-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors)) + --@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|extra-jwt-issuers-headers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|ready-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors)) return 0 ;; esac diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index b57d5aed..5d4af4e7 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -58,6 +58,7 @@ type Options struct { SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` BearerTokenLoginFallback bool `flag:"bearer-token-login-fallback" cfg:"bearer_token_login_fallback"` ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers"` + ExtraJwtIssuersHeaders []string `flag:"extra-jwt-issuers-headers" cfg:"extra_jwt_issuers_headers"` SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` @@ -135,6 +136,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Bool("encode-state", false, "will encode oauth state with base64") flagSet.Bool("allow-query-semicolons", false, "allow the use of semicolons in query args") flagSet.StringSlice("extra-jwt-issuers", []string{}, "if skip-jwt-bearer-tokens is set, a list of extra JWT issuer=audience pairs (where the issuer URL has a .well-known/openid-configuration or a .well-known/jwks.json)") + flagSet.StringSlice("extra-jwt-issuers-headers", []string{}, "Allows setting a header when .well-known/openid-configuration is called. specified as key=value") flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . or a *. to allow subdomains (eg .example.com, *.example.com)") diff --git a/pkg/providers/oidc/provider.go b/pkg/providers/oidc/provider.go index 427f79a6..716dfdca 100644 --- a/pkg/providers/oidc/provider.go +++ b/pkg/providers/oidc/provider.go @@ -47,7 +47,13 @@ type DiscoveryProvider interface { // We implement this here as opposed to using oidc.Provider so that we can override the Issuer verification check. // As we have our own verifier and fetch the userinfo separately, the rest of the oidc.Provider implementation is not // useful to us. -func NewProvider(ctx context.Context, issuerURL string, skipIssuerVerification bool) (DiscoveryProvider, error) { +// +// Parameters: +// ctx: The context for the function execution. +// issuerURL: The URL of the OIDC issuer to perform discovery against. +// skipIssuerVerification: A boolean flag indicating whether to skip issuer verification. +// IssuerCustomHeaders: A map of custom headers to be used when calling the issuer for discovery. +func NewProvider(ctx context.Context, issuerURL string, skipIssuerVerification bool, issuerCustomHeaders map[string]string) (DiscoveryProvider, error) { // go-oidc doesn't let us pass bypass the issuer check this in the oidc.NewProvider call // (which uses discovery to get the URLs), so we'll do a quick check ourselves and if // we get the URLs, we'll just use the non-discovery path. @@ -56,10 +62,17 @@ func NewProvider(ctx context.Context, issuerURL string, skipIssuerVerification b var p providerJSON requestURL := strings.TrimSuffix(issuerURL, "/") + "/.well-known/openid-configuration" - if err := requests.New(requestURL).WithContext(ctx).Do().UnmarshalInto(&p); err != nil { - return nil, fmt.Errorf("failed to discover OIDC configuration: %v", err) + + request := requests.New(requestURL). + WithContext(ctx) + + for key, value := range issuerCustomHeaders { + request = request.SetHeader(key, value) } + if err := request.Do().UnmarshalInto(&p); err != nil { + return nil, fmt.Errorf("failed to discover OIDC configuration: %v", err) + } if !skipIssuerVerification && p.Issuer != issuerURL { return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuerURL, p.Issuer) } diff --git a/pkg/providers/oidc/provider_test.go b/pkg/providers/oidc/provider_test.go index 0364721a..c29613a4 100644 --- a/pkg/providers/oidc/provider_test.go +++ b/pkg/providers/oidc/provider_test.go @@ -37,7 +37,7 @@ var _ = Describe("Provider", func() { Expect(m.Shutdown()).To(Succeed()) }() - provider, err := NewProvider(context.Background(), m.Issuer(), in.skipIssuerVerification) + provider, err := NewProvider(context.Background(), m.Issuer(), in.skipIssuerVerification, make(map[string]string)) if in.expectedError != "" { Expect(err).To(MatchError(HavePrefix(in.expectedError))) return @@ -97,7 +97,7 @@ var _ = Describe("Provider", func() { Expect(m.Shutdown()).To(Succeed()) }() - provider, err := NewProvider(context.Background(), m.Issuer(), false) + provider, err := NewProvider(context.Background(), m.Issuer(), false, make(map[string]string)) Expect(err).ToNot(HaveOccurred()) Expect(provider.PKCE().CodeChallengeAlgs).To(ConsistOf("S256", "plain")) @@ -116,7 +116,7 @@ var _ = Describe("Provider", func() { Expect(m.Shutdown()).To(Succeed()) }() - provider, err := NewProvider(context.Background(), m.Issuer(), false) + provider, err := NewProvider(context.Background(), m.Issuer(), false, make(map[string]string)) Expect(err).ToNot(HaveOccurred()) Expect(provider.SupportedSigningAlgs()).To(ConsistOf("RS256", "HS256")) diff --git a/pkg/providers/oidc/provider_verifier.go b/pkg/providers/oidc/provider_verifier.go index 0457a9dc..699823ab 100644 --- a/pkg/providers/oidc/provider_verifier.go +++ b/pkg/providers/oidc/provider_verifier.go @@ -38,6 +38,9 @@ type ProviderVerifierOptions struct { // eg: https://accounts.google.com IssuerURL string + // IssuerCustomHeaders defines optional header values to be set when calling the given provider. + IssuerCustomHeaders map[string]string + // JWKsURL is the OpenID Connect JWKS URL // eg: https://www.googleapis.com/oauth2/v3/certs JWKsURL string @@ -150,7 +153,7 @@ func getVerifierBuilder(ctx context.Context, opts ProviderVerifierOptions) (veri ), nil, nil } - provider, err := NewProvider(ctx, opts.IssuerURL, opts.SkipIssuerVerification) + provider, err := NewProvider(ctx, opts.IssuerURL, opts.SkipIssuerVerification, opts.IssuerCustomHeaders) if err != nil { return nil, nil, fmt.Errorf("error while discovery OIDC configuration: %w", err) } diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 13ce2e0b..e66525c4 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -57,11 +57,20 @@ func Validate(o *options.Options) error { if len(o.ExtraJwtIssuers) > 0 { var jwtIssuers []jwtIssuer jwtIssuers, msgs = parseJwtIssuers(o.ExtraJwtIssuers, msgs) + + var jwtIssuersHeaders map[string]string + if len(o.ExtraJwtIssuersHeaders) > 0 { + jwtIssuersHeaders, msgs = parseJwtIssuerHeader(o.ExtraJwtIssuersHeaders, msgs) + } else { + jwtIssuersHeaders = make(map[string]string) // Initialize an empty map if headers are not provided + } + for _, jwtIssuer := range jwtIssuers { verifier, err := newVerifierFromJwtIssuer( o.Providers[0].OIDCConfig.AudienceClaims, o.Providers[0].OIDCConfig.ExtraAudiences, jwtIssuer, + jwtIssuersHeaders, ) if err != nil { msgs = append(msgs, fmt.Sprintf("error building verifiers: %s", err)) @@ -141,14 +150,61 @@ func parseJwtIssuers(issuers []string, msgs []string) ([]jwtIssuer, []string) { return parsedIssuers, msgs } +// parseJwtIssuerHeader takes in an array of header strings in the form of "headerKey=headerValue" +// and parses them to return a map of key-value pairs and any error messages. +// +// Parameters: +// +// headers: A slice of strings representing headerKey=headerValue entries. +// msgs: A slice of strings to collect error messages, if any. +// +// Returns: +// +// map[string]string: A map of key-value pairs extracted from the headers. +// []string: A slice of strings containing any error messages encountered during parsing. +// +// Description: +// This function parses the input headers and extracts key-value pairs from them. +// Each entry in the "headers" slice should be in the format "headerKey=headerValue". +// The function checks if each entry contains both a non-empty key and a non-empty value. +// If so, it adds the key-value pair to the resulting map. If any errors are encountered +// during parsing, they are appended to the "msgs" slice. +// The function returns the map of key-value pairs and the error messages. +func parseJwtIssuerHeader(headers []string, msgs []string) (map[string]string, []string) { + result := make(map[string]string) + + if len(headers) == 0 { + msgs = append(msgs, "empty header array") + return result, msgs + } + + for _, headerItem := range headers { + components := strings.SplitN(strings.TrimSpace(headerItem), "=", 2) + if len(components) != 2 { + msgs = append(msgs, fmt.Sprintf("invalid jwt issuer header format, expected header_name=header_value: %s", headerItem)) + continue + } + + key := strings.TrimSpace(components[0]) + value := strings.TrimSpace(components[1]) + if key != "" && value != "" { + // Add the non-empty key and value to the result map + result[key] = value + } + } + + return result, msgs +} + // newVerifierFromJwtIssuer takes in issuer information in jwtIssuer info and returns // a verifier for that issuer. -func newVerifierFromJwtIssuer(audienceClaims []string, extraAudiences []string, jwtIssuer jwtIssuer) (internaloidc.IDTokenVerifier, error) { +func newVerifierFromJwtIssuer(audienceClaims []string, extraAudiences []string, jwtIssuer jwtIssuer, jwtIssuersHeaders map[string]string) (internaloidc.IDTokenVerifier, error) { pvOpts := internaloidc.ProviderVerifierOptions{ - AudienceClaims: audienceClaims, - ClientID: jwtIssuer.audience, - ExtraAudiences: extraAudiences, - IssuerURL: jwtIssuer.issuerURI, + AudienceClaims: audienceClaims, + ClientID: jwtIssuer.audience, + ExtraAudiences: extraAudiences, + IssuerURL: jwtIssuer.issuerURI, + IssuerCustomHeaders: jwtIssuersHeaders, } pv, err := internaloidc.NewProviderVerifier(context.TODO(), pvOpts)