234 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			234 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Go
		
	
	
	
| package oidc
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/golang-jwt/jwt/v5"
 | |
| 	"github.com/oauth2-proxy/mockoidc"
 | |
| 	. "github.com/onsi/ginkgo/v2"
 | |
| 	. "github.com/onsi/gomega"
 | |
| )
 | |
| 
 | |
// Suite-wide fixtures. They are assigned in BeforeSuite, and the directory is
// removed again in AfterSuite.
var (
	// tempDir is the scratch directory that holds the key files written for the suite.
	tempDir string
	// invalidPublicKeyFilePath is the path of the key file with unparsable contents.
	invalidPublicKeyFilePath string
	// validPublicKeyFilePath is the path of the key file with a PUBLIC KEY PEM block.
	validPublicKeyFilePath string
)
 | |
| 
 | |
| var _ = BeforeSuite(func() {
 | |
| 	var err error
 | |
| 
 | |
| 	// Create a temporary directory and public key file
 | |
| 	tempDir, err = os.MkdirTemp("/tmp", "provider-verifier-test")
 | |
| 	Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 	invalidPublicKeyFilePath = filepath.Join(tempDir, "invalid.key")
 | |
| 	validPublicKeyFilePath = filepath.Join(tempDir, "valid.key")
 | |
| 
 | |
| 	invalidKeyContents := []byte(`-----BEGIN INVALID KEY-----
 | |
| ThisIsNotAValidKey
 | |
| -----END INVALID KEY-----`)
 | |
| 	Expect(os.WriteFile(invalidPublicKeyFilePath, invalidKeyContents, 0644)).To(Succeed())
 | |
| 
 | |
| 	validKeyContents := []byte(`-----BEGIN PUBLIC KEY-----
 | |
| MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALBJK+8qU+aQu2bHxJ8E95AIu2NINztM
 | |
| NmX9R2zI9xlXN8wGQG8kWLYoRLbyiZwY9kdzOBGvYci64wHIjtFswHcCAwEAAQ==
 | |
| -----END PUBLIC KEY-----`)
 | |
| 	Expect(os.WriteFile(validPublicKeyFilePath, validKeyContents, 0644)).To(Succeed())
 | |
| })
 | |
| 
 | |
| var _ = AfterSuite(func() {
 | |
| 	// Clean up temporary directory
 | |
| 	Expect(os.RemoveAll(tempDir)).To(Succeed())
 | |
| })
 | |
| 
 | |
| var _ = Describe("ProviderVerifier", func() {
 | |
| 	var m *mockoidc.MockOIDC
 | |
| 
 | |
| 	BeforeEach(func() {
 | |
| 		var err error
 | |
| 		m, err = mockoidc.Run()
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 	})
 | |
| 
 | |
| 	AfterEach(func() {
 | |
| 		Expect(m.Shutdown()).To(Succeed())
 | |
| 	})
 | |
| 
 | |
| 	type newProviderVerifierTableInput struct {
 | |
| 		modifyOpts    func(*ProviderVerifierOptions)
 | |
| 		expectedError string
 | |
| 	}
 | |
| 
 | |
| 	DescribeTable("when constructing the provider verifier", func(in *newProviderVerifierTableInput) {
 | |
| 		opts := ProviderVerifierOptions{
 | |
| 			AudienceClaims: []string{"aud"},
 | |
| 			ClientID:       m.Config().ClientID,
 | |
| 			ExtraAudiences: []string{},
 | |
| 			IssuerURL:      m.Issuer(),
 | |
| 		}
 | |
| 		if in.modifyOpts != nil {
 | |
| 			in.modifyOpts(&opts)
 | |
| 		}
 | |
| 
 | |
| 		pv, err := NewProviderVerifier(context.Background(), opts)
 | |
| 		if in.expectedError != "" {
 | |
| 			Expect(err).To(MatchError(HavePrefix(in.expectedError)))
 | |
| 			return
 | |
| 		}
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		Expect(pv.DiscoveryEnabled()).ToNot(Equal(opts.SkipDiscovery), "DiscoveryEnabled should be the reverse of skip discovery")
 | |
| 		Expect(pv.Provider()).ToNot(BeNil())
 | |
| 
 | |
| 		if pv.DiscoveryEnabled() {
 | |
| 			endpoints := pv.Provider().Endpoints()
 | |
| 			Expect(endpoints.AuthURL).To(Equal(m.AuthorizationEndpoint()))
 | |
| 			Expect(endpoints.TokenURL).To(Equal(m.TokenEndpoint()))
 | |
| 			Expect(endpoints.JWKsURL).To(Equal(m.JWKSEndpoint()))
 | |
| 			Expect(endpoints.UserInfoURL).To(Equal(m.UserinfoEndpoint()))
 | |
| 		}
 | |
| 	},
 | |
| 		Entry("should be succesfful when discovering the OIDC provider", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(_ *ProviderVerifierOptions) {},
 | |
| 		}),
 | |
| 		Entry("when the issuer URL is missing", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.IssuerURL = ""
 | |
| 			},
 | |
| 			expectedError: "invalid provider verifier options: missing required setting: issuer-url",
 | |
| 		}),
 | |
| 		Entry("when the issuer URL is invalid", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.IssuerURL = "invalid"
 | |
| 			},
 | |
| 			expectedError: "could not get verifier builder: error while discovery OIDC configuration: failed to discover OIDC configuration: error performing request: Get \"invalid/.well-known/openid-configuration\": unsupported protocol scheme \"\"",
 | |
| 		}),
 | |
| 		Entry("with skip discovery and the JWKs URL is missing", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.SkipDiscovery = true
 | |
| 				p.JWKsURL = ""
 | |
| 			},
 | |
| 			expectedError: "invalid provider verifier options: missing required setting: jwks-url or public-key-files",
 | |
| 		}),
 | |
| 		Entry("with skip discovery, the JWKs URL not empty and len(PublicKeyFiles) is greater than 0", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.SkipDiscovery = true
 | |
| 				p.JWKsURL = "notEmpty"
 | |
| 				p.PublicKeyFiles = []string{"notEmpty"}
 | |
| 			},
 | |
| 			expectedError: "invalid provider verifier options: mutually exclusive settings: jwks-url and public-key-files",
 | |
| 		}),
 | |
| 		Entry("should be successful when skipping discovery with the JWKs URL specified", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.SkipDiscovery = true
 | |
| 				p.JWKsURL = m.JWKSEndpoint()
 | |
| 			},
 | |
| 		}),
 | |
| 		Entry("should pass when the key is valid", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.SkipDiscovery = true
 | |
| 				p.PublicKeyFiles = []string{validPublicKeyFilePath}
 | |
| 			},
 | |
| 		}),
 | |
| 		Entry("should fail when the key is invalid", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.SkipDiscovery = true
 | |
| 				p.PublicKeyFiles = []string{invalidPublicKeyFilePath}
 | |
| 			},
 | |
| 			expectedError: "could not get verifier builder: error while parsing public keys",
 | |
| 		}),
 | |
| 		Entry("should fail when the key file is not found", &newProviderVerifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.SkipDiscovery = true
 | |
| 				p.PublicKeyFiles = []string{"non-existing"}
 | |
| 			},
 | |
| 			expectedError: "could not get verifier builder: error while parsing public keys: failed to read file",
 | |
| 		}),
 | |
| 	)
 | |
| 
 | |
| 	type verifierTableInput struct {
 | |
| 		modifyOpts    func(*ProviderVerifierOptions)
 | |
| 		modifyClaims  func(claims *jwt.RegisteredClaims)
 | |
| 		expectedError string
 | |
| 	}
 | |
| 
 | |
| 	DescribeTable("when constructing the provider verifier", func(in *verifierTableInput) {
 | |
| 		opts := ProviderVerifierOptions{
 | |
| 			AudienceClaims: []string{"aud"},
 | |
| 			ClientID:       m.Config().ClientID,
 | |
| 			ExtraAudiences: []string{},
 | |
| 			IssuerURL:      m.Issuer(),
 | |
| 		}
 | |
| 		if in.modifyOpts != nil {
 | |
| 			in.modifyOpts(&opts)
 | |
| 		}
 | |
| 
 | |
| 		pv, err := NewProviderVerifier(context.Background(), opts)
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		now := time.Now()
 | |
| 		claims := jwt.RegisteredClaims{
 | |
| 			Audience:  jwt.ClaimStrings{m.Config().ClientID},
 | |
| 			Issuer:    m.Issuer(),
 | |
| 			ExpiresAt: jwt.NewNumericDate(now.Add(1 * time.Hour)),
 | |
| 			IssuedAt:  jwt.NewNumericDate(now),
 | |
| 			Subject:   "user",
 | |
| 		}
 | |
| 		if in.modifyClaims != nil {
 | |
| 			in.modifyClaims(&claims)
 | |
| 		}
 | |
| 
 | |
| 		rawIDToken, err := m.Keypair.SignJWT(claims)
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		idToken, err := pv.Verifier().Verify(context.Background(), rawIDToken)
 | |
| 		if in.expectedError != "" {
 | |
| 			Expect(err).To(MatchError(HavePrefix(in.expectedError)))
 | |
| 			return
 | |
| 		}
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		Expect(idToken.Issuer).To(Equal(claims.Issuer))
 | |
| 		Expect(idToken.Audience).To(ConsistOf(claims.Audience))
 | |
| 		Expect(idToken.Subject).To(Equal(claims.Subject))
 | |
| 	},
 | |
| 		Entry("with the default opts and claims", &verifierTableInput{}),
 | |
| 		Entry("when the audience is mismatched", &verifierTableInput{
 | |
| 			modifyClaims: func(j *jwt.RegisteredClaims) {
 | |
| 				j.Audience = jwt.ClaimStrings{"OtherClient"}
 | |
| 			},
 | |
| 			expectedError: "audience from claim aud with value [OtherClient] does not match with any of allowed audiences",
 | |
| 		}),
 | |
| 		Entry("when the audience is an extra audience", &verifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.ExtraAudiences = []string{"ExtraIssuer"}
 | |
| 			},
 | |
| 			modifyClaims: func(j *jwt.RegisteredClaims) {
 | |
| 				j.Audience = jwt.ClaimStrings{"ExtraIssuer"}
 | |
| 			},
 | |
| 		}),
 | |
| 		Entry("when the issuer is mismatched", &verifierTableInput{
 | |
| 			modifyClaims: func(j *jwt.RegisteredClaims) {
 | |
| 				j.Issuer = "OtherIssuer"
 | |
| 			},
 | |
| 			expectedError: "failed to verify token: oidc: id token issued by a different provider",
 | |
| 		}),
 | |
| 		Entry("when the issuer is mismatched with skip issuer verification", &verifierTableInput{
 | |
| 			modifyOpts: func(p *ProviderVerifierOptions) {
 | |
| 				p.SkipIssuerVerification = true
 | |
| 			},
 | |
| 			modifyClaims: func(j *jwt.RegisteredClaims) {
 | |
| 				j.Issuer = "OtherIssuer"
 | |
| 			},
 | |
| 		}),
 | |
| 		Entry("when the token has expired", &verifierTableInput{
 | |
| 			modifyClaims: func(j *jwt.RegisteredClaims) {
 | |
| 				j.ExpiresAt = jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
 | |
| 			},
 | |
| 			expectedError: "failed to verify token: oidc: token is expired",
 | |
| 		}),
 | |
| 	)
 | |
| })
 |