154 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			154 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
| package oidc
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"encoding/json"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 
 | |
| 	"github.com/oauth2-proxy/mockoidc"
 | |
| 	. "github.com/onsi/ginkgo"
 | |
| 	. "github.com/onsi/ginkgo/extensions/table"
 | |
| 	. "github.com/onsi/gomega"
 | |
| )
 | |
| 
 | |
| var _ = Describe("Provider", func() {
 | |
| 	type newProviderTableInput struct {
 | |
| 		skipIssuerVerification bool
 | |
| 		expectedError          string
 | |
| 		middlewares            func(*mockoidc.MockOIDC) []func(http.Handler) http.Handler
 | |
| 	}
 | |
| 
 | |
| 	DescribeTable("NewProvider", func(in *newProviderTableInput) {
 | |
| 		m, err := mockoidc.NewServer(nil)
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		if in.middlewares != nil {
 | |
| 			middlewares := in.middlewares(m)
 | |
| 			for _, middlware := range middlewares {
 | |
| 				m.AddMiddleware(middlware)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		ln, err := net.Listen("tcp", "127.0.0.1:0")
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		Expect(m.Start(ln, nil)).To(Succeed())
 | |
| 		defer func() {
 | |
| 			Expect(m.Shutdown()).To(Succeed())
 | |
| 		}()
 | |
| 
 | |
| 		provider, err := NewProvider(context.Background(), m.Issuer(), in.skipIssuerVerification)
 | |
| 		if in.expectedError != "" {
 | |
| 			Expect(err).To(MatchError(HavePrefix(in.expectedError)))
 | |
| 			return
 | |
| 		}
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		endpoints := provider.Endpoints()
 | |
| 		Expect(endpoints.AuthURL).To(Equal(m.AuthorizationEndpoint()))
 | |
| 		Expect(endpoints.TokenURL).To(Equal(m.TokenEndpoint()))
 | |
| 		Expect(endpoints.JWKsURL).To(Equal(m.JWKSEndpoint()))
 | |
| 		Expect(endpoints.UserInfoURL).To(Equal(m.UserinfoEndpoint()))
 | |
| 	},
 | |
| 		Entry("with issuer verification and the issuer matches", &newProviderTableInput{
 | |
| 			skipIssuerVerification: false,
 | |
| 		}),
 | |
| 		Entry("with skip issuer verification and the issuer matches", &newProviderTableInput{
 | |
| 			skipIssuerVerification: true,
 | |
| 		}),
 | |
| 		Entry("with issuer verification and an invalid issuer", &newProviderTableInput{
 | |
| 			skipIssuerVerification: false,
 | |
| 			middlewares: func(m *mockoidc.MockOIDC) []func(http.Handler) http.Handler {
 | |
| 				return []func(http.Handler) http.Handler{
 | |
| 					newInvalidIssuerMiddleware(m),
 | |
| 				}
 | |
| 			},
 | |
| 			expectedError: "oidc: issuer did not match the issuer returned by provider",
 | |
| 		}),
 | |
| 		Entry("with skip issuer verification and an invalid issuer", &newProviderTableInput{
 | |
| 			skipIssuerVerification: true,
 | |
| 			middlewares: func(m *mockoidc.MockOIDC) []func(http.Handler) http.Handler {
 | |
| 				return []func(http.Handler) http.Handler{
 | |
| 					newInvalidIssuerMiddleware(m),
 | |
| 				}
 | |
| 			},
 | |
| 		}),
 | |
| 		Entry("when the issuer returns a bad response", &newProviderTableInput{
 | |
| 			skipIssuerVerification: false,
 | |
| 			middlewares: func(m *mockoidc.MockOIDC) []func(http.Handler) http.Handler {
 | |
| 				return []func(http.Handler) http.Handler{
 | |
| 					newBadRequestMiddleware(),
 | |
| 				}
 | |
| 			},
 | |
| 			expectedError: "failed to discover OIDC configuration: unexpected status \"400\"",
 | |
| 		}),
 | |
| 	)
 | |
| 
 | |
| 	It("with code challenges supported on the provider, shold populate PKCE information", func() {
 | |
| 		m, err := mockoidc.NewServer(nil)
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 		m.AddMiddleware(newCodeChallengeIssuerMiddleware(m))
 | |
| 
 | |
| 		ln, err := net.Listen("tcp", "127.0.0.1:0")
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		Expect(m.Start(ln, nil)).To(Succeed())
 | |
| 		defer func() {
 | |
| 			Expect(m.Shutdown()).To(Succeed())
 | |
| 		}()
 | |
| 
 | |
| 		provider, err := NewProvider(context.Background(), m.Issuer(), false)
 | |
| 		Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 		Expect(provider.PKCE().CodeChallengeAlgs).To(ConsistOf("S256", "plain"))
 | |
| 	})
 | |
| })
 | |
| 
 | |
| func newInvalidIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Handler {
 | |
| 	return func(next http.Handler) http.Handler {
 | |
| 		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 			p := providerJSON{
 | |
| 				Issuer:      "invalid",
 | |
| 				AuthURL:     m.AuthorizationEndpoint(),
 | |
| 				TokenURL:    m.TokenEndpoint(),
 | |
| 				JWKsURL:     m.JWKSEndpoint(),
 | |
| 				UserInfoURL: m.UserinfoEndpoint(),
 | |
| 			}
 | |
| 			data, err := json.Marshal(p)
 | |
| 			if err != nil {
 | |
| 				rw.WriteHeader(500)
 | |
| 			}
 | |
| 			rw.Write(data)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func newCodeChallengeIssuerMiddleware(m *mockoidc.MockOIDC) func(http.Handler) http.Handler {
 | |
| 	return func(next http.Handler) http.Handler {
 | |
| 		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | |
| 			p := providerJSON{
 | |
| 				Issuer:            m.Issuer(),
 | |
| 				AuthURL:           m.AuthorizationEndpoint(),
 | |
| 				TokenURL:          m.TokenEndpoint(),
 | |
| 				JWKsURL:           m.JWKSEndpoint(),
 | |
| 				UserInfoURL:       m.UserinfoEndpoint(),
 | |
| 				CodeChallengeAlgs: []string{"S256", "plain"},
 | |
| 			}
 | |
| 			data, err := json.Marshal(p)
 | |
| 			if err != nil {
 | |
| 				rw.WriteHeader(500)
 | |
| 			}
 | |
| 			rw.Write(data)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
// newBadRequestMiddleware returns middleware that short-circuits every request
// with an empty 400 Bad Request response. The wrapped handler is never invoked.
func newBadRequestMiddleware() func(http.Handler) http.Handler {
	// One stateless handler can be shared by every wrapping.
	reject := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	})
	return func(_ http.Handler) http.Handler {
		return reject
	}
}
 |