From 4980f6af7d86337aa66cfb3dfbebccb0d9df8b80 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Tue, 22 Jun 2021 18:50:47 -0700 Subject: [PATCH] Use upn claim as a fallback in Enrich & Refresh Only when `email` claim is missing, fallback to `upn` claim which may have it. --- providers/adfs.go | 66 ++++++++++++++++++++++++++++++---- providers/adfs_test.go | 82 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 138 insertions(+), 10 deletions(-) diff --git a/providers/adfs.go b/providers/adfs.go index 89119c7a..95177d76 100644 --- a/providers/adfs.go +++ b/providers/adfs.go @@ -1,14 +1,22 @@ package providers import ( + "context" + "fmt" "net/url" "strings" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" ) // ADFSProvider represents an ADFS based Identity Provider type ADFSProvider struct { *OIDCProvider + skipScope bool + // Expose for unit testing + oidcEnrichFunc func(context.Context, *sessions.SessionState) error + oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error) } var _ Provider = (*ADFSProvider)(nil) @@ -17,7 +25,7 @@ const ( adfsProviderName = "ADFS" adfsDefaultScope = "openid email profile" adfsSkipScope = false - adfsEmailClaim = "upn" + adfsUPNClaim = "upn" ) // NewADFSProvider initiates a new ADFSProvider @@ -26,7 +34,6 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider { name: adfsProviderName, scope: adfsDefaultScope, }) - p.EmailClaim = adfsEmailClaim if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { resource := p.ProtectedResource.String() @@ -39,12 +46,16 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider { } } + oidcProvider := &OIDCProvider{ + ProviderData: p, + SkipNonce: true, + } + return &ADFSProvider{ - OIDCProvider: &OIDCProvider{ - ProviderData: p, - SkipNonce: true, - }, - skipScope: adfsSkipScope, + OIDCProvider: oidcProvider, + skipScope: adfsSkipScope, + oidcEnrichFunc: oidcProvider.EnrichSession, + oidcRefreshFunc: oidcProvider.RefreshSession, } } @@ -68,3 +79,44 @@ func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string { } return loginURL.String() } + +// EnrichSession calls the OIDC ProfileURL to backfill any fields missing +// from the claims. If Email is missing, falls back to ADFS `upn` claim. +func (p *ADFSProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { + err := p.oidcEnrichFunc(ctx, s) + if err != nil { + return err + } + + if s.Email == "" { + return p.fallbackUPN(ctx, s) + } + return nil +} + +// RefreshSession refreshes via the OIDC implementation. If email is missing, +// falls back to ADFS `upn` claim. +func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { + refreshed, err := p.oidcRefreshFunc(ctx, s) + if err != nil || s.Email != "" { + return refreshed, err + } + err = p.fallbackUPN(ctx, s) + return refreshed, err +} + +func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { + idToken, err := p.Verifier.Verify(ctx, s.IDToken) + if err != nil { + return err + } + claims, err := p.getClaims(idToken) + if err != nil { + return fmt.Errorf("couldn't extract claims from id_token (%v)", err) + } + upn := claims.raw[adfsUPNClaim] + if upn != nil { + s.Email = fmt.Sprint(upn) + } + return nil +} diff --git a/providers/adfs_test.go b/providers/adfs_test.go index ae306cd1..adbec455 100644 --- a/providers/adfs_test.go +++ b/providers/adfs_test.go @@ -2,6 +2,8 @@ package providers import ( "context" + "crypto/rand" + "crypto/rsa" "encoding/base64" "net/http" "net/http/httptest" @@ -9,6 +11,7 @@ import ( "strings" "github.com/coreos/go-oidc/v3/oidc" + "github.com/dgrijalva/jwt-go" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" @@ -25,8 +28,18 @@ func (fakeADFSJwks) VerifySignature(_ context.Context, jwt string) (payload []by return decodeString, nil } -func testADFSProvider(hostname string) *ADFSProvider { +type adfsClaims struct { + UPN string `json:"upn,omitempty"` + idTokenClaims +} +func newSignedTestADFSToken(tokenClaims adfsClaims) (string, error) { + key, _ := rsa.GenerateKey(rand.Reader, 2048) + standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) + return standardClaims.SignedString(key) +} + +func testADFSProvider(hostname string) *ADFSProvider { o := oidc.NewVerifier( "https://issuer.example.com", fakeADFSJwks{}, @@ -41,6 +54,7 @@ func testADFSProvider(hostname string) *ADFSProvider { ValidateURL: &url.URL{}, Scope: "", Verifier: o, + EmailClaim: OIDCEmailClaim, }) if hostname != "" { @@ -54,7 +68,6 @@ func testADFSProvider(hostname string) *ADFSProvider { } func testADFSBackend() *httptest.Server { - authResponse := ` { "access_token": "my_access_token", @@ -129,7 +142,6 @@ var _ = Describe("ADFS Provider Tests", func() { Context("with valid token", func() { It("should not throw an error", func() { - p.EmailClaim = "email" rawIDToken, _ := newSignedTestIDToken(defaultIDToken) idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) Expect(err).To(BeNil()) @@ -202,4 +214,68 @@ var _ = Describe("ADFS Provider Tests", func() { }), ) }) + + Context("UPN Fallback", func() { + var idToken string + var session *sessions.SessionState + + BeforeEach(func() { + var err error + idToken, err = newSignedTestADFSToken(adfsClaims{ + UPN: "upn@company.com", + idTokenClaims: minimalIDToken, + }) + Expect(err).ToNot(HaveOccurred()) + + session = &sessions.SessionState{ + IDToken: idToken, + } + }) + + Describe("EnrichSession", func() { + It("uses email claim if present", func() { + p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { + s.Email = "person@company.com" + return nil + } + + err := p.EnrichSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("person@company.com")) + }) + + It("falls back to UPN claim if Email is missing", func() { + p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { + return nil + } + + err := p.EnrichSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("upn@company.com")) + }) + }) + + Describe("RefreshSession", func() { + It("uses email claim if present", func() { + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + s.Email = "person@company.com" + return true, nil + } + + _, err := p.RefreshSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("person@company.com")) + }) + + It("falls back to UPN claim if Email is missing", func() { + p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { + return true, nil + } + + _, err := p.RefreshSession(context.Background(), session) + Expect(err).ToNot(HaveOccurred()) + Expect(session.Email).To(Equal("upn@company.com")) + }) + }) + }) })