diff --git a/CHANGELOG.md b/CHANGELOG.md index 2aa2ef46..33a03e0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ## Changes since v7.13.0 - [#3197](https://github.com/oauth2-proxy/oauth2-proxy/pull/3197) fix: NewRemoteKeySet is not using DefaultHTTPClient (@rsrdesarrollo / @tuunit) +- [#2869](https://github.com/oauth2-proxy/oauth2-proxy/pull/2869) feat: pass along oidc config to verifier (@brian-mcnamara) # V7.13.0 diff --git a/pkg/providers/oidc/provider_verifier.go b/pkg/providers/oidc/provider_verifier.go index eac80a8c..7dfb321f 100644 --- a/pkg/providers/oidc/provider_verifier.go +++ b/pkg/providers/oidc/provider_verifier.go @@ -17,6 +17,7 @@ import ( // ProviderVerifier represents the OIDC discovery and verification process type ProviderVerifier interface { DiscoveryEnabled() bool + JwksURL() string Provider() DiscoveryProvider Verifier() IDTokenVerifier } @@ -120,6 +121,7 @@ func NewProviderVerifier(ctx context.Context, opts ProviderVerifierOptions) (Pro return &providerVerifier{ discoveryEnabled: !opts.SkipDiscovery, + jwksURL: opts.JWKsURL, provider: provider, verifier: verifier, }, nil @@ -215,6 +217,7 @@ func newVerifierBuilder(issuerURL string, keySet oidc.KeySet, supportedSigningAl // providerVerifier is an implementation of the ProviderVerifier interface type providerVerifier struct { discoveryEnabled bool + jwksURL string provider DiscoveryProvider verifier IDTokenVerifier } @@ -234,3 +237,8 @@ func (p *providerVerifier) Provider() DiscoveryProvider { func (p *providerVerifier) Verifier() IDTokenVerifier { return p.verifier } + +// JwksURL retuns the jwks url used for the provider +func (p *providerVerifier) JwksURL() string { + return p.jwksURL +} diff --git a/pkg/providers/oidc/provider_verifier_test.go b/pkg/providers/oidc/provider_verifier_test.go index ff91e016..17e9592f 100644 --- a/pkg/providers/oidc/provider_verifier_test.go +++ b/pkg/providers/oidc/provider_verifier_test.go @@ -2,6 +2,8 @@ package oidc import ( "context" + "net" + "net/http" "os" "path/filepath" "time" @@ -45,10 +47,24 @@ var _ = AfterSuite(func() { var _ = Describe("ProviderVerifier", func() { var m *mockoidc.MockOIDC + var endpointsHit []string BeforeEach(func() { var err error - m, err = mockoidc.Run() + m, err = mockoidc.NewServer(nil) + Expect(err).ToNot(HaveOccurred()) + ln, err := net.Listen("tcp", "127.0.0.1:0") + Expect(err).ToNot(HaveOccurred()) + + endpointsHit = []string{} + m.AddMiddleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + endpointsHit = append(endpointsHit, req.URL.Path) + next.ServeHTTP(rw, req) + }) + }) + + err = m.Start(ln, nil) Expect(err).ToNot(HaveOccurred()) }) @@ -88,9 +104,12 @@ var _ = Describe("ProviderVerifier", func() { Expect(endpoints.TokenURL).To(Equal(m.TokenEndpoint())) Expect(endpoints.JWKsURL).To(Equal(m.JWKSEndpoint())) Expect(endpoints.UserInfoURL).To(Equal(m.UserinfoEndpoint())) + Expect(endpointsHit).To(ContainElement("/oidc/.well-known/openid-configuration")) + } else { + Expect(endpointsHit).ToNot(ContainElement("/oidc/.well-known/openid-configuration")) } }, - Entry("should be succesfful when discovering the OIDC provider", &newProviderVerifierTableInput{ + Entry("should be successful when discovering the OIDC provider", &newProviderVerifierTableInput{ modifyOpts: func(_ *ProviderVerifierOptions) {}, }), Entry("when the issuer URL is missing", &newProviderVerifierTableInput{ diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 13ce2e0b..85b14d0b 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -59,8 +59,7 @@ func Validate(o *options.Options) error { jwtIssuers, msgs = parseJwtIssuers(o.ExtraJwtIssuers, msgs) for _, jwtIssuer := range jwtIssuers { verifier, err := newVerifierFromJwtIssuer( - o.Providers[0].OIDCConfig.AudienceClaims, - o.Providers[0].OIDCConfig.ExtraAudiences, + o.Providers[0].OIDCConfig, jwtIssuer, ) if err != nil { @@ -143,12 +142,26 @@ func parseJwtIssuers(issuers []string, msgs []string) ([]jwtIssuer, []string) { // newVerifierFromJwtIssuer takes in issuer information in jwtIssuer info and returns // a verifier for that issuer. -func newVerifierFromJwtIssuer(audienceClaims []string, extraAudiences []string, jwtIssuer jwtIssuer) (internaloidc.IDTokenVerifier, error) { +func newVerifierFromJwtIssuer(odicOptions options.OIDCOptions, jwtIssuer jwtIssuer) (internaloidc.IDTokenVerifier, error) { + pv, err := newProviderVerifierFromJwtIssuer(odicOptions, jwtIssuer) + + if err != nil { + return nil, err + } + + return pv.Verifier(), nil +} + +// newProviderVerifierFromJwtIssuer takes in issuer information in jwtIssuer info and returns +// a ProviderVerifier for that issuer. +func newProviderVerifierFromJwtIssuer(odicOptions options.OIDCOptions, jwtIssuer jwtIssuer) (internaloidc.ProviderVerifier, error) { pvOpts := internaloidc.ProviderVerifierOptions{ - AudienceClaims: audienceClaims, + AudienceClaims: odicOptions.AudienceClaims, ClientID: jwtIssuer.audience, - ExtraAudiences: extraAudiences, + ExtraAudiences: odicOptions.ExtraAudiences, IssuerURL: jwtIssuer.issuerURI, + SkipDiscovery: odicOptions.SkipDiscovery, + JWKsURL: odicOptions.JwksURL, } pv, err := internaloidc.NewProviderVerifier(context.TODO(), pvOpts) @@ -163,7 +176,7 @@ func newVerifierFromJwtIssuer(audienceClaims []string, extraAudiences []string, } } - return pv.Verifier(), nil + return pv, nil } // jwtIssuer hold parsed JWT issuer info that's used to construct a verifier. diff --git a/pkg/validation/options_test.go b/pkg/validation/options_test.go index 5ea748c1..639cdb28 100644 --- a/pkg/validation/options_test.go +++ b/pkg/validation/options_test.go @@ -236,3 +236,39 @@ func TestProviderCAFilesError(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "unable to load provider CA file(s)") } + +func TestProviderVerifierSkipsDiscoveryWhenConfigured(t *testing.T) { + jwksUrl := "https://examle.com/auth/certs" + o := testOptions() + o.Providers[0].OIDCConfig.SkipDiscovery = true + o.Providers[0].OIDCConfig.JwksURL = jwksUrl + + jwtIssuer := jwtIssuer{ + issuerURI: "https://example.com", + audience: "aud", + } + + pv, err := newProviderVerifierFromJwtIssuer(o.Providers[0].OIDCConfig, jwtIssuer) + assert.Equal(t, nil, err) + assert.NotEqual(t, nil, pv) + + assert.Equal(t, jwksUrl, pv.JwksURL()) + assert.False(t, pv.DiscoveryEnabled()) +} + +func TestProviderVerifierUsesFallback(t *testing.T) { + issuerURI := "https://example.com" + o := testOptions() + + jwtIssuer := jwtIssuer{ + issuerURI: issuerURI, + audience: "aud", + } + + pv, err := newProviderVerifierFromJwtIssuer(o.Providers[0].OIDCConfig, jwtIssuer) + assert.Equal(t, nil, err) + assert.NotEqual(t, nil, pv) + + assert.Equal(t, issuerURI+"/.well-known/jwks.json", pv.JwksURL()) + assert.False(t, pv.DiscoveryEnabled()) +}