diff --git a/pkg/providers/oidc/provider.go b/pkg/providers/oidc/provider.go index ddc8a2e1..d1b88644 100644 --- a/pkg/providers/oidc/provider.go +++ b/pkg/providers/oidc/provider.go @@ -11,11 +11,12 @@ import ( // providerJSON resresents the information we need from an OIDC discovery type providerJSON struct { - Issuer string `json:"issuer"` - AuthURL string `json:"authorization_endpoint"` - TokenURL string `json:"token_endpoint"` - JWKsURL string `json:"jwks_uri"` - UserInfoURL string `json:"userinfo_endpoint"` + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKsURL string `json:"jwks_uri"` + UserInfoURL string `json:"userinfo_endpoint"` + CodeChallengeAlgs []string `json:"code_challenge_methods_supported"` } // Endpoints represents the endpoints discovered as part of the OIDC discovery process @@ -27,10 +28,17 @@ type Endpoints struct { UserInfoURL string } +// PKCE holds information relevant to the PKCE (code challenge) support of the +// provider. +type PKCE struct { + CodeChallengeAlgs []string +} + // DiscoveryProvider holds information about an identity provider having // used OIDC discovery to retrieve the information. type DiscoveryProvider interface { Endpoints() Endpoints + PKCE() PKCE } // NewProvider allows a user to perform an OIDC discovery and returns the DiscoveryProvider. @@ -55,19 +63,21 @@ func NewProvider(ctx context.Context, issuerURL string, skipIssuerVerification b } return &discoveryProvider{ - authURL: p.AuthURL, - tokenURL: p.TokenURL, - jwksURL: p.JWKsURL, - userInfoURL: p.UserInfoURL, + authURL: p.AuthURL, + tokenURL: p.TokenURL, + jwksURL: p.JWKsURL, + userInfoURL: p.UserInfoURL, + codeChallengeAlgs: p.CodeChallengeAlgs, }, nil } // discoveryProvider holds the discovered endpoints type discoveryProvider struct { - authURL string - tokenURL string - jwksURL string - userInfoURL string + authURL string + tokenURL string + jwksURL string + userInfoURL string + codeChallengeAlgs []string } // Endpoints returns the discovered endpoints needed for an authentication provider. @@ -79,3 +89,10 @@ func (p *discoveryProvider) Endpoints() Endpoints { UserInfoURL: p.userInfoURL, } } + +// PKCE returns information related to the PKCE (code challenge) support of the provider. +func (p *discoveryProvider) PKCE() PKCE { + return PKCE{ + CodeChallengeAlgs: p.codeChallengeAlgs, + } +} diff --git a/pkg/providers/oidc/provider_test.go b/pkg/providers/oidc/provider_test.go index 731bfb93..2f822e97 100644 --- a/pkg/providers/oidc/provider_test.go +++ b/pkg/providers/oidc/provider_test.go @@ -84,6 +84,25 @@ var _ = Describe("Provider", func() { expectedError: "failed to discover OIDC configuration: unexpected status \"400\"", }), ) + + It("with code challenges supported on the provider, shold populate PKCE information", func() { + m, err := mockoidc.NewServer(nil) + Expect(err).ToNot(HaveOccurred()) + m.AddMiddleware(newCodeChallengeIssuerMiddleware(m)) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + Expect(err).ToNot(HaveOccurred()) + + Expect(m.Start(ln, nil)).To(Succeed()) + defer func() { + Expect(m.Shutdown()).To(Succeed()) + }() + + provider, err := NewProvider(context.Background(), m.Issuer(), false) + Expect(err).ToNot(HaveOccurred()) + + Expect(provider.PKCE().CodeChallengeAlgs).To(ConsistOf("S256", "plain")) + }) }) func newInvalidIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Handler { @@ -105,6 +124,26 @@ func newInvalidIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Ha } } +func newCodeChallengeIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + p := providerJSON{ + Issuer: m.Issuer(), + AuthURL: m.AuthorizationEndpoint(), + TokenURL: m.TokenEndpoint(), + JWKsURL: m.JWKSEndpoint(), + UserInfoURL: m.UserinfoEndpoint(), + CodeChallengeAlgs: []string{"S256", "plain"}, + } + data, err := json.Marshal(p) + if err != nil { + rw.WriteHeader(500) + } + rw.Write(data) + }) + } +} + func newBadRequestMiddleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {