diff --git a/CHANGELOG.md b/CHANGELOG.md index 320ba697..a531ad50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ ## Changes since v7.15.2 +- [#xxxx](https://github.com/oauth2-proxy/oauth2-proxy/pull/xxxx) feat: add `--trusted-issuer-prefix` flag for dynamic JWT verification via issuer URL prefixes + # V7.15.2 ## Release Highlights diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index ac2b13c8..c184ea70 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -59,6 +59,7 @@ type Options struct { SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` BearerTokenLoginFallback bool `flag:"bearer-token-login-fallback" cfg:"bearer_token_login_fallback"` ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers"` + TrustedIssuerPrefixes []string `flag:"trusted-issuer-prefix" cfg:"trusted_issuer_prefixes"` SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` @@ -137,6 +138,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Bool("encode-state", false, "will encode oauth state with base64") flagSet.Bool("allow-query-semicolons", false, "allow the use of semicolons in query args") flagSet.StringSlice("extra-jwt-issuers", []string{}, "if skip-jwt-bearer-tokens is set, a list of extra JWT issuer=audience pairs (where the issuer URL has a .well-known/openid-configuration or a .well-known/jwks.json)") + flagSet.StringSlice("trusted-issuer-prefix", []string{}, "if skip-jwt-bearer-tokens is set, a list of issuer URL prefix=audience pairs. Any JWT whose issuer starts with the prefix will be dynamically verified via OIDC discovery (e.g. https://keycloak.example.com/realms/TENANT_=my-client-id)") flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . or a *. to allow subdomains (eg .example.com, *.example.com)") diff --git a/pkg/providers/oidc/prefix_verifier.go b/pkg/providers/oidc/prefix_verifier.go new file mode 100644 index 00000000..4c4d9922 --- /dev/null +++ b/pkg/providers/oidc/prefix_verifier.go @@ -0,0 +1,157 @@ +package oidc + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" +) + +// PrefixVerifier dynamically verifies JWT tokens whose issuer matches a configured +// URL prefix. For each unique issuer discovered at runtime, it performs OIDC discovery +// and caches the resulting verifier. +type PrefixVerifier struct { + prefix string + audience string + audienceClaims []string + extraAudiences []string + + mu sync.RWMutex + verifiers map[string]IDTokenVerifier +} + +// PrefixVerifierOptions configures a PrefixVerifier. +type PrefixVerifierOptions struct { + // Prefix is the issuer URL prefix to match (e.g. "https://keycloak.example.com/realms/TENANT_") + Prefix string + // Audience is the expected audience (client_id) + Audience string + // AudienceClaims specifies which claims to check for audience + AudienceClaims []string + // ExtraAudiences are additional allowed audiences + ExtraAudiences []string +} + +// NewPrefixVerifier creates a new PrefixVerifier. +func NewPrefixVerifier(opts PrefixVerifierOptions) *PrefixVerifier { + return &PrefixVerifier{ + prefix: opts.Prefix, + audience: opts.Audience, + audienceClaims: opts.AudienceClaims, + extraAudiences: opts.ExtraAudiences, + verifiers: make(map[string]IDTokenVerifier), + } +} + +// Verify checks if the token's issuer matches the prefix and verifies it. +func (pv *PrefixVerifier) Verify(ctx context.Context, rawToken string) (*oidc.IDToken, error) { + issuer, err := extractIssuerFromJWT(rawToken) + if err != nil { + return nil, fmt.Errorf("prefix verifier: failed to extract issuer: %w", err) + } + + if !strings.HasPrefix(issuer, pv.prefix) { + return nil, fmt.Errorf("prefix verifier: issuer %q does not match prefix %q", issuer, pv.prefix) + } + + verifier, err := pv.getOrCreateVerifier(ctx, issuer) + if err != nil { + return nil, fmt.Errorf("prefix verifier: failed to get verifier for issuer %q: %w", issuer, err) + } + + return verifier.Verify(ctx, rawToken) +} + +// getOrCreateVerifier retrieves a cached verifier or creates one via OIDC discovery. +func (pv *PrefixVerifier) getOrCreateVerifier(ctx context.Context, issuer string) (IDTokenVerifier, error) { + // Fast path: check cache with read lock + pv.mu.RLock() + v, ok := pv.verifiers[issuer] + pv.mu.RUnlock() + if ok { + return v, nil + } + + // Slow path: create verifier with write lock (double-checked locking) + pv.mu.Lock() + defer pv.mu.Unlock() + + // Re-check after acquiring write lock + if v, ok := pv.verifiers[issuer]; ok { + return v, nil + } + + verifier, err := pv.createVerifier(ctx, issuer) + if err != nil { + return nil, err + } + + pv.verifiers[issuer] = verifier + return verifier, nil +} + +// createVerifier performs OIDC discovery for the given issuer and creates a verifier. +func (pv *PrefixVerifier) createVerifier(ctx context.Context, issuer string) (IDTokenVerifier, error) { + ctx = oidc.ClientContext(ctx, requests.DefaultHTTPClient) + + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + // Fall back to JWKs URL if discovery fails + jwksURL := strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json" + keySet := oidc.NewRemoteKeySet(ctx, jwksURL) + oidcConfig := &oidc.Config{ + ClientID: pv.audience, + SkipIssuerCheck: false, + SkipClientIDCheck: true, + } + rawVerifier := oidc.NewVerifier(issuer, keySet, oidcConfig) + return NewVerifier(rawVerifier, IDTokenVerificationOptions{ + AudienceClaims: pv.audienceClaims, + ClientID: pv.audience, + ExtraAudiences: pv.extraAudiences, + }), nil + } + + oidcConfig := &oidc.Config{ + ClientID: pv.audience, + SkipIssuerCheck: false, + SkipClientIDCheck: true, + } + rawVerifier := provider.Verifier(oidcConfig) + return NewVerifier(rawVerifier, IDTokenVerificationOptions{ + AudienceClaims: pv.audienceClaims, + ClientID: pv.audience, + ExtraAudiences: pv.extraAudiences, + }), nil +} + +// extractIssuerFromJWT decodes the payload of a JWT (without verification) to extract the "iss" claim. +func extractIssuerFromJWT(rawToken string) (string, error) { + parts := strings.Split(rawToken, ".") + if len(parts) != 3 { + return "", fmt.Errorf("token has %d parts, expected 3", len(parts)) + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("failed to decode JWT payload: %w", err) + } + + var claims struct { + Issuer string `json:"iss"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("failed to unmarshal JWT claims: %w", err) + } + + if claims.Issuer == "" { + return "", fmt.Errorf("JWT has no issuer claim") + } + + return claims.Issuer, nil +} diff --git a/pkg/providers/oidc/prefix_verifier_test.go b/pkg/providers/oidc/prefix_verifier_test.go new file mode 100644 index 00000000..de6018cf --- /dev/null +++ b/pkg/providers/oidc/prefix_verifier_test.go @@ -0,0 +1,321 @@ +package oidc + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("PrefixVerifier", func() { + var ( + privateKey *rsa.PrivateKey + server *httptest.Server + issuerURL string + ) + + BeforeEach(func() { + var err error + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + + // Set up a mock OIDC server that serves discovery and JWKS + mux := http.NewServeMux() + server = httptest.NewServer(mux) + issuerURL = server.URL + "/realms/SOT_TENANT_A" + + // OpenID Configuration endpoint + mux.HandleFunc("/realms/SOT_TENANT_A/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + config := map[string]interface{}{ + "issuer": issuerURL, + "jwks_uri": server.URL + "/realms/SOT_TENANT_A/.well-known/jwks.json", + "authorization_endpoint": server.URL + "/realms/SOT_TENANT_A/protocol/openid-connect/auth", + "token_endpoint": server.URL + "/realms/SOT_TENANT_A/protocol/openid-connect/token", + "id_token_signing_alg_values_supported": []string{"RS256"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(config) + }) + + // JWKS endpoint + mux.HandleFunc("/realms/SOT_TENANT_A/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { + n := base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()) + e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(privateKey.E)).Bytes()) + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "RSA", + "kid": "test-key-1", + "use": "sig", + "alg": "RS256", + "n": n, + "e": e, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(jwks) + }) + }) + + AfterEach(func() { + if server != nil { + server.Close() + } + }) + + Context("extractIssuerFromJWT", func() { + It("should extract the issuer from a valid JWT", func() { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": "https://example.com/realms/TEST", + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + issuer, err := extractIssuerFromJWT(signed) + Expect(err).ToNot(HaveOccurred()) + Expect(issuer).To(Equal("https://example.com/realms/TEST")) + }) + + It("should return error for invalid JWT format", func() { + _, err := extractIssuerFromJWT("not-a-jwt") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("expected 3")) + }) + + It("should return error for JWT without issuer", func() { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "sub": "user-1", + "exp": time.Now().Add(time.Hour).Unix(), + }) + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + _, err = extractIssuerFromJWT(signed) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no issuer claim")) + }) + }) + + Context("Verify", func() { + It("should reject token whose issuer does not match prefix", func() { + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: "https://keycloak.example.com/realms/SOT_", + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": "https://evil.example.com/realms/SOT_HACK", + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + _, err = pv.Verify(context.Background(), signed) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("does not match prefix")) + }) + + It("should verify a valid token with matching prefix", func() { + prefix := server.URL + "/realms/SOT_" + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: prefix, + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuerURL, + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + idToken, err := pv.Verify(context.Background(), signed) + Expect(err).ToNot(HaveOccurred()) + Expect(idToken).ToNot(BeNil()) + Expect(idToken.Issuer).To(Equal(issuerURL)) + Expect(idToken.Subject).To(Equal("user-1")) + }) + + It("should cache verifiers for repeated calls", func() { + prefix := server.URL + "/realms/SOT_" + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: prefix, + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuerURL, + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + // First call — triggers OIDC discovery + _, err = pv.Verify(context.Background(), signed) + Expect(err).ToNot(HaveOccurred()) + + // Second call — should use cached verifier + _, err = pv.Verify(context.Background(), signed) + Expect(err).ToNot(HaveOccurred()) + + // Verify cache has one entry + pv.mu.RLock() + Expect(pv.verifiers).To(HaveLen(1)) + Expect(pv.verifiers).To(HaveKey(issuerURL)) + pv.mu.RUnlock() + }) + + It("should reject token with wrong audience", func() { + prefix := server.URL + "/realms/SOT_" + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: prefix, + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuerURL, + "sub": "user-1", + "aud": "wrong-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + _, err = pv.Verify(context.Background(), signed) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("audience")) + }) + + It("should reject an expired token", func() { + prefix := server.URL + "/realms/SOT_" + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: prefix, + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuerURL, + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(-time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + _, err = pv.Verify(context.Background(), signed) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to verify token")) + }) + + It("should reject token signed with wrong key", func() { + prefix := server.URL + "/realms/SOT_" + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: prefix, + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + // Generate a different key + wrongKey, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuerURL, + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(wrongKey) + Expect(err).ToNot(HaveOccurred()) + + _, err = pv.Verify(context.Background(), signed) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to verify token")) + }) + }) + + Context("prefix matching security", func() { + It("should not match partial prefix overlaps", func() { + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: "https://keycloak.example.com/realms/SOT_", + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + // Issuer that looks similar but comes from a different host + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": "https://keycloak.example.com.evil.com/realms/SOT_HACK", + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + _, err = pv.Verify(context.Background(), signed) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("does not match prefix")) + }) + + It("should match exact prefix boundary", func() { + pv := NewPrefixVerifier(PrefixVerifierOptions{ + Prefix: server.URL + "/realms/SOT_", + Audience: "my-client", + AudienceClaims: []string{"aud"}, + }) + + // SOT_TENANT_A starts with SOT_ ✓ + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": issuerURL, // ends with /realms/SOT_TENANT_A + "sub": "user-1", + "aud": "my-client", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-1" + signed, err := token.SignedString(privateKey) + Expect(err).ToNot(HaveOccurred()) + + idToken, err := pv.Verify(context.Background(), signed) + Expect(err).ToNot(HaveOccurred()) + Expect(idToken.Issuer).To(Equal(issuerURL)) + }) + }) +}) + +// Suppress unused import warnings +var _ = fmt.Sprintf +var _ = strings.HasPrefix diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 13ce2e0b..213bc741 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -69,6 +69,21 @@ func Validate(o *options.Options) error { o.SetJWTBearerVerifiers(append(o.GetJWTBearerVerifiers(), verifier)) } } + + // Configure trusted issuer prefixes (dynamic verifiers) + if len(o.TrustedIssuerPrefixes) > 0 { + var prefixIssuers []jwtIssuer + prefixIssuers, msgs = parseJwtIssuers(o.TrustedIssuerPrefixes, msgs) + for _, pi := range prefixIssuers { + pv := internaloidc.NewPrefixVerifier(internaloidc.PrefixVerifierOptions{ + Prefix: pi.issuerURI, + Audience: pi.audience, + AudienceClaims: o.Providers[0].OIDCConfig.AudienceClaims, + ExtraAudiences: o.Providers[0].OIDCConfig.ExtraAudiences, + }) + o.SetJWTBearerVerifiers(append(o.GetJWTBearerVerifiers(), pv)) + } + } } var redirectURL *url.URL