From a1dec50415a3c2ab7ba1e75c2a6358787d02a2b6 Mon Sep 17 00:00:00 2001 From: Dimitri John Ledkov Date: Thu, 2 Apr 2026 12:24:54 +0100 Subject: [PATCH] test: add NewProvider test case for IssuerCustomHeaders Verifies that custom headers passed to NewProvider are included in the OIDC discovery request, using a middleware that rejects requests missing the expected header. Signed-off-by: Dimitri John Ledkov --- pkg/providers/oidc/provider_test.go | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/pkg/providers/oidc/provider_test.go b/pkg/providers/oidc/provider_test.go index c29613a4..12ee9e0e 100644 --- a/pkg/providers/oidc/provider_test.go +++ b/pkg/providers/oidc/provider_test.go @@ -15,6 +15,7 @@ var _ = Describe("Provider", func() { type newProviderTableInput struct { skipIssuerVerification bool expectedError string + customHeaders map[string]string middlewares func(*mockoidc.MockOIDC) []func(http.Handler) http.Handler } @@ -37,7 +38,11 @@ var _ = Describe("Provider", func() { Expect(m.Shutdown()).To(Succeed()) }() - provider, err := NewProvider(context.Background(), m.Issuer(), in.skipIssuerVerification, make(map[string]string)) + customHeaders := in.customHeaders + if customHeaders == nil { + customHeaders = make(map[string]string) + } + provider, err := NewProvider(context.Background(), m.Issuer(), in.skipIssuerVerification, customHeaders) if in.expectedError != "" { Expect(err).To(MatchError(HavePrefix(in.expectedError))) return @@ -82,6 +87,15 @@ var _ = Describe("Provider", func() { }, expectedError: "failed to discover OIDC configuration: unexpected status \"400\"", }), + Entry("with custom headers, sends them in the discovery request", &newProviderTableInput{ + skipIssuerVerification: false, + customHeaders: map[string]string{"X-Custom-Header": "custom-value"}, + middlewares: func(m *mockoidc.MockOIDC) []func(http.Handler) http.Handler { + return []func(http.Handler) http.Handler{ + newRequiredHeaderMiddleware("X-Custom-Header", "custom-value"), + } + }, + }), ) It("with code challenges supported on the provider, shold populate PKCE information", func() { @@ -189,3 +203,15 @@ func newBadRequestMiddleware() func(http.Handler) http.Handler { }) } } + +func newRequiredHeaderMiddleware(key, value string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get(key) != value { + rw.WriteHeader(http.StatusUnauthorized) + return + } + next.ServeHTTP(rw, req) + }) + } +}