diff --git a/CHANGELOG.md b/CHANGELOG.md index 8defa4f8..b94310f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.2.0 +- [#1247](https://github.com/oauth2-proxy/oauth2-proxy/pull/1247) Use `upn` claim consistently in ADFSProvider (@NickMeves) - [#1447](https://github.com/oauth2-proxy/oauth2-proxy/pull/1447) Fix docker build/push issues found during last release (@JoelSpeed) - [#1433](https://github.com/oauth2-proxy/oauth2-proxy/pull/1433) Let authentication fail when session validation fails (@stippi2) - [#1445](https://github.com/oauth2-proxy/oauth2-proxy/pull/1445) Fix docker container multi arch build issue by passing GOARCH details to make build (@jkandasa) diff --git a/providers/adfs.go b/providers/adfs.go index c626bd09..797c8566 100644 --- a/providers/adfs.go +++ b/providers/adfs.go @@ -2,7 +2,6 @@ package providers import ( "context" - "errors" "fmt" "net/url" "strings" @@ -13,23 +12,27 @@ import ( // ADFSProvider represents an ADFS based Identity Provider type ADFSProvider struct { *OIDCProvider - SkipScope bool + + skipScope bool + // Expose for unit testing + oidcEnrichFunc func(context.Context, *sessions.SessionState) error + oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error) } var _ Provider = (*ADFSProvider)(nil) const ( - ADFSProviderName = "ADFS" - ADFSDefaultScope = "openid email profile" - ADFSSkipScope = false + adfsProviderName = "ADFS" + adfsDefaultScope = "openid email profile" + adfsSkipScope = false + adfsUPNClaim = "upn" ) // NewADFSProvider initiates a new ADFSProvider func NewADFSProvider(p *ProviderData) *ADFSProvider { - p.setProviderDefaults(providerDefaults{ - name: ADFSProviderName, - scope: ADFSDefaultScope, + name: adfsProviderName, + scope: adfsDefaultScope, }) if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { @@ -43,18 +46,22 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider { } } + oidcProvider := &OIDCProvider{ + ProviderData: p, + SkipNonce: false, + } + return &ADFSProvider{ - OIDCProvider: &OIDCProvider{ - ProviderData: p, - SkipNonce: true, - }, - SkipScope: ADFSSkipScope, + OIDCProvider: oidcProvider, + skipScope: adfsSkipScope, + oidcEnrichFunc: oidcProvider.EnrichSession, + oidcRefreshFunc: oidcProvider.RefreshSession, } } // Configure defaults the ADFSProvider configuration options func (p *ADFSProvider) Configure(skipScope bool) { - p.SkipScope = skipScope + p.skipScope = skipScope } // GetLoginURL Override to double encode the state parameter. If not query params are lost @@ -65,7 +72,7 @@ func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string { extraParams.Add("nonce", nonce) } loginURL := makeLoginURL(p.Data(), redirectURI, url.QueryEscape(state), extraParams) - if p.SkipScope { + if p.skipScope { q := loginURL.Query() q.Del("scope") loginURL.RawQuery = q.Encode() @@ -73,28 +80,40 @@ func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string { return loginURL.String() } -// EnrichSession to add email +// EnrichSession calls the OIDC ProfileURL to backfill any fields missing +// from the claims. If Email is missing, falls back to ADFS `upn` claim. func (p *ADFSProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { - if s.Email != "" { - return nil + err := p.oidcEnrichFunc(ctx, s) + if err != nil || s.Email == "" { + // OIDC only errors if email is missing + return p.fallbackUPN(ctx, s) } + return nil +} +// RefreshSession refreshes via the OIDC implementation. If email is missing, +// falls back to ADFS `upn` claim. +func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + refreshed, err := p.oidcRefreshFunc(ctx, s) + if err != nil || s.Email != "" { + return refreshed, err + } + err = p.fallbackUPN(ctx, s) + return refreshed, err +} + +func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { idToken, err := p.Verifier.Verify(ctx, s.IDToken) if err != nil { return err } - - p.EmailClaim = "upn" - c, err := p.getClaims(idToken) - + claims, err := p.getClaims(idToken) if err != nil { return fmt.Errorf("couldn't extract claims from id_token (%v)", err) } - s.Email = c.Email - - if s.Email == "" { - err = errors.New("email not set") + upn := claims.raw[adfsUPNClaim] + if upn != nil { + s.Email = fmt.Sprint(upn) } - - return err + return nil } diff --git a/providers/adfs_test.go b/providers/adfs_test.go index bd2a82ce..a12a4e53 100644 --- a/providers/adfs_test.go +++ b/providers/adfs_test.go @@ -2,13 +2,17 @@ package providers import ( "context" + "crypto/rand" + "crypto/rsa" "encoding/base64" + "errors" "net/http" "net/http/httptest" "net/url" "strings" "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" @@ -25,8 +29,18 @@ func (fakeADFSJwks) VerifySignature(_ context.Context, jwt string) (payload []by return decodeString, nil } -func testADFSProvider(hostname string) *ADFSProvider { +type adfsClaims struct { + UPN string `json:"upn,omitempty"` + idTokenClaims +} +func newSignedTestADFSToken(tokenClaims adfsClaims) (string, error) { + key, _ := rsa.GenerateKey(rand.Reader, 2048) + standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) + return standardClaims.SignedString(key) +} + +func testADFSProvider(hostname string) *ADFSProvider { o := oidc.NewVerifier( "https://issuer.example.com", fakeADFSJwks{}, @@ -41,6 +55,7 @@ func testADFSProvider(hostname string) *ADFSProvider { ValidateURL: &url.URL{}, Scope: "", Verifier: o, + EmailClaim: OIDCEmailClaim, }) if hostname != "" { @@ -54,7 +69,6 @@ func testADFSProvider(hostname string) *ADFSProvider { } func testADFSBackend() *httptest.Server { - authResponse := ` { "access_token": "my_access_token", @@ -129,13 +143,12 @@ var _ = Describe("ADFS Provider Tests", func() { Context("with valid token", func() { It("should not throw an error", func() { - p.EmailClaim = "email" rawIDToken, _ := newSignedTestIDToken(defaultIDToken) idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) Expect(err).To(BeNil()) session, err := p.buildSessionFromClaims(idToken) - session.IDToken = rawIDToken Expect(err).To(BeNil()) + session.IDToken = rawIDToken err = p.EnrichSession(context.Background(), session) Expect(session.Email).To(Equal("janed@me.com")) Expect(err).To(BeNil()) @@ -149,7 +162,7 @@ var _ = Describe("ADFS Provider Tests", func() { ProtectedResource: resource, Scope: "", }) - p.SkipScope = true + p.skipScope = true result := p.GetLoginURL("https://example.com/adfs/oauth2/", "", "") Expect(result).NotTo(ContainSubstring("scope=")) @@ -202,4 +215,78 @@ var _ = Describe("ADFS Provider Tests", func() { }), ) }) + + Context("UPN Fallback", func() { + var idToken string + var session *sessions.SessionState + + BeforeEach(func() { + var err error + idToken, err = newSignedTestADFSToken(adfsClaims{ + UPN: "upn@company.com", + idTokenClaims: minimalIDToken, + }) + Expect(err).ToNot(HaveOccurred()) + + session = &sessions.SessionState{ + IDToken: idToken, + } + }) + + Describe("EnrichSession", func() { + It("uses email claim if present", func() { + p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { + s.Email = "person@company.com" + return nil + } + + err := p.EnrichSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("person@company.com")) + }) + + It("falls back to UPN claim if Email is missing", func() { + p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { + return nil + } + + err := p.EnrichSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("upn@company.com")) + }) + + It("falls back to UPN claim on errors", func() { + p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { + return errors.New("neither the id_token nor the profileURL set an email") + } + + err := p.EnrichSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("upn@company.com")) + }) + }) + + Describe("RefreshSession", func() { + It("uses email claim if present", func() { + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + s.Email = "person@company.com" + return true, nil + } + + _, err := p.RefreshSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("person@company.com")) + }) + + It("falls back to UPN claim if Email is missing", func() { + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + return true, nil + } + + _, err := p.RefreshSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("upn@company.com")) + }) + }) + }) })