Use upn claim as a fallback in Enrich & Refresh
Only when `email` claim is missing, fallback to `upn` claim which may have it.
This commit is contained in:
		
							parent
							
								
									a53198725e
								
							
						
					
					
						commit
						4980f6af7d
					
				|  | @ -1,14 +1,22 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // ADFSProvider represents an ADFS based Identity Provider
 | ||||
| type ADFSProvider struct { | ||||
| 	*OIDCProvider | ||||
| 
 | ||||
| 	skipScope bool | ||||
| 	// Expose for unit testing
 | ||||
| 	oidcEnrichFunc  func(context.Context, *sessions.SessionState) error | ||||
| 	oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error) | ||||
| } | ||||
| 
 | ||||
| var _ Provider = (*ADFSProvider)(nil) | ||||
|  | @ -17,7 +25,7 @@ const ( | |||
| 	adfsProviderName = "ADFS" | ||||
| 	adfsDefaultScope = "openid email profile" | ||||
| 	adfsSkipScope    = false | ||||
| 	adfsEmailClaim   = "upn" | ||||
| 	adfsUPNClaim     = "upn" | ||||
| ) | ||||
| 
 | ||||
| // NewADFSProvider initiates a new ADFSProvider
 | ||||
|  | @ -26,7 +34,6 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider { | |||
| 		name:  adfsProviderName, | ||||
| 		scope: adfsDefaultScope, | ||||
| 	}) | ||||
| 	p.EmailClaim = adfsEmailClaim | ||||
| 
 | ||||
| 	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { | ||||
| 		resource := p.ProtectedResource.String() | ||||
|  | @ -39,12 +46,16 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider { | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return &ADFSProvider{ | ||||
| 		OIDCProvider: &OIDCProvider{ | ||||
| 	oidcProvider := &OIDCProvider{ | ||||
| 		ProviderData: p, | ||||
| 		SkipNonce:    true, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	return &ADFSProvider{ | ||||
| 		OIDCProvider:    oidcProvider, | ||||
| 		skipScope:       adfsSkipScope, | ||||
| 		oidcEnrichFunc:  oidcProvider.EnrichSession, | ||||
| 		oidcRefreshFunc: oidcProvider.RefreshSession, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -68,3 +79,44 @@ func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string { | |||
| 	} | ||||
| 	return loginURL.String() | ||||
| } | ||||
| 
 | ||||
| // EnrichSession calls the OIDC ProfileURL to backfill any fields missing
 | ||||
| // from the claims. If Email is missing, falls back to ADFS `upn` claim.
 | ||||
| func (p *ADFSProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | ||||
| 	err := p.oidcEnrichFunc(ctx, s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if s.Email == "" { | ||||
| 		return p.fallbackUPN(ctx, s) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RefreshSession refreshes via the OIDC implementation. If email is missing,
 | ||||
| // falls back to ADFS `upn` claim.
 | ||||
| func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 	refreshed, err := p.oidcRefreshFunc(ctx, s) | ||||
| 	if err != nil || s.Email != "" { | ||||
| 		return refreshed, err | ||||
| 	} | ||||
| 	err = p.fallbackUPN(ctx, s) | ||||
| 	return refreshed, err | ||||
| } | ||||
| 
 | ||||
| func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { | ||||
| 	idToken, err := p.Verifier.Verify(ctx, s.IDToken) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	claims, err := p.getClaims(idToken) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("couldn't extract claims from id_token (%v)", err) | ||||
| 	} | ||||
| 	upn := claims.raw[adfsUPNClaim] | ||||
| 	if upn != nil { | ||||
| 		s.Email = fmt.Sprint(upn) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
|  | @ -2,6 +2,8 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/rsa" | ||||
| 	"encoding/base64" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
|  | @ -9,6 +11,7 @@ import ( | |||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc/v3/oidc" | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
|  | @ -25,8 +28,18 @@ func (fakeADFSJwks) VerifySignature(_ context.Context, jwt string) (payload []by | |||
| 	return decodeString, nil | ||||
| } | ||||
| 
 | ||||
| func testADFSProvider(hostname string) *ADFSProvider { | ||||
| type adfsClaims struct { | ||||
| 	UPN string `json:"upn,omitempty"` | ||||
| 	idTokenClaims | ||||
| } | ||||
| 
 | ||||
| func newSignedTestADFSToken(tokenClaims adfsClaims) (string, error) { | ||||
| 	key, _ := rsa.GenerateKey(rand.Reader, 2048) | ||||
| 	standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) | ||||
| 	return standardClaims.SignedString(key) | ||||
| } | ||||
| 
 | ||||
| func testADFSProvider(hostname string) *ADFSProvider { | ||||
| 	o := oidc.NewVerifier( | ||||
| 		"https://issuer.example.com", | ||||
| 		fakeADFSJwks{}, | ||||
|  | @ -41,6 +54,7 @@ func testADFSProvider(hostname string) *ADFSProvider { | |||
| 		ValidateURL:  &url.URL{}, | ||||
| 		Scope:        "", | ||||
| 		Verifier:     o, | ||||
| 		EmailClaim:   OIDCEmailClaim, | ||||
| 	}) | ||||
| 
 | ||||
| 	if hostname != "" { | ||||
|  | @ -54,7 +68,6 @@ func testADFSProvider(hostname string) *ADFSProvider { | |||
| } | ||||
| 
 | ||||
| func testADFSBackend() *httptest.Server { | ||||
| 
 | ||||
| 	authResponse := ` | ||||
| 		{ | ||||
| 			"access_token": "my_access_token", | ||||
|  | @ -129,7 +142,6 @@ var _ = Describe("ADFS Provider Tests", func() { | |||
| 
 | ||||
| 	Context("with valid token", func() { | ||||
| 		It("should not throw an error", func() { | ||||
| 			p.EmailClaim = "email" | ||||
| 			rawIDToken, _ := newSignedTestIDToken(defaultIDToken) | ||||
| 			idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) | ||||
| 			Expect(err).To(BeNil()) | ||||
|  | @ -202,4 +214,68 @@ var _ = Describe("ADFS Provider Tests", func() { | |||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UPN Fallback", func() { | ||||
| 		var idToken string | ||||
| 		var session *sessions.SessionState | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			var err error | ||||
| 			idToken, err = newSignedTestADFSToken(adfsClaims{ | ||||
| 				UPN:           "upn@company.com", | ||||
| 				idTokenClaims: minimalIDToken, | ||||
| 			}) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			session = &sessions.SessionState{ | ||||
| 				IDToken: idToken, | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		Describe("EnrichSession", func() { | ||||
| 			It("uses email claim if present", func() { | ||||
| 				p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { | ||||
| 					s.Email = "person@company.com" | ||||
| 					return nil | ||||
| 				} | ||||
| 
 | ||||
| 				err := p.EnrichSession(context.Background(), session) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(session.Email).To(Equal("person@company.com")) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("falls back to UPN claim if Email is missing", func() { | ||||
| 				p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { | ||||
| 					return nil | ||||
| 				} | ||||
| 
 | ||||
| 				err := p.EnrichSession(context.Background(), session) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(session.Email).To(Equal("upn@company.com")) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Describe("RefreshSession", func() { | ||||
| 			It("uses email claim if present", func() { | ||||
| 				p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 					s.Email = "person@company.com" | ||||
| 					return true, nil | ||||
| 				} | ||||
| 
 | ||||
| 				_, err := p.RefreshSession(context.Background(), session) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(session.Email).To(Equal("person@company.com")) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("falls back to UPN claim if Email is missing", func() { | ||||
| 				p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 					return true, nil | ||||
| 				} | ||||
| 
 | ||||
| 				_, err := p.RefreshSession(context.Background(), session) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(session.Email).To(Equal("upn@company.com")) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue