Use upn claim as a fallback in Enrich & Refresh
Only when `email` claim is missing, fallback to `upn` claim which may have it.
This commit is contained in:
		
							parent
							
								
									a53198725e
								
							
						
					
					
						commit
						4980f6af7d
					
				|  | @ -1,14 +1,22 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // ADFSProvider represents an ADFS based Identity Provider
 | // ADFSProvider represents an ADFS based Identity Provider
 | ||||||
| type ADFSProvider struct { | type ADFSProvider struct { | ||||||
| 	*OIDCProvider | 	*OIDCProvider | ||||||
|  | 
 | ||||||
| 	skipScope bool | 	skipScope bool | ||||||
|  | 	// Expose for unit testing
 | ||||||
|  | 	oidcEnrichFunc  func(context.Context, *sessions.SessionState) error | ||||||
|  | 	oidcRefreshFunc func(context.Context, *sessions.SessionState) (bool, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| var _ Provider = (*ADFSProvider)(nil) | var _ Provider = (*ADFSProvider)(nil) | ||||||
|  | @ -17,7 +25,7 @@ const ( | ||||||
| 	adfsProviderName = "ADFS" | 	adfsProviderName = "ADFS" | ||||||
| 	adfsDefaultScope = "openid email profile" | 	adfsDefaultScope = "openid email profile" | ||||||
| 	adfsSkipScope    = false | 	adfsSkipScope    = false | ||||||
| 	adfsEmailClaim   = "upn" | 	adfsUPNClaim     = "upn" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // NewADFSProvider initiates a new ADFSProvider
 | // NewADFSProvider initiates a new ADFSProvider
 | ||||||
|  | @ -26,7 +34,6 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider { | ||||||
| 		name:  adfsProviderName, | 		name:  adfsProviderName, | ||||||
| 		scope: adfsDefaultScope, | 		scope: adfsDefaultScope, | ||||||
| 	}) | 	}) | ||||||
| 	p.EmailClaim = adfsEmailClaim |  | ||||||
| 
 | 
 | ||||||
| 	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { | 	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { | ||||||
| 		resource := p.ProtectedResource.String() | 		resource := p.ProtectedResource.String() | ||||||
|  | @ -39,12 +46,16 @@ func NewADFSProvider(p *ProviderData) *ADFSProvider { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	oidcProvider := &OIDCProvider{ | ||||||
|  | 		ProviderData: p, | ||||||
|  | 		SkipNonce:    true, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return &ADFSProvider{ | 	return &ADFSProvider{ | ||||||
| 		OIDCProvider: &OIDCProvider{ | 		OIDCProvider:    oidcProvider, | ||||||
| 			ProviderData: p, | 		skipScope:       adfsSkipScope, | ||||||
| 			SkipNonce:    true, | 		oidcEnrichFunc:  oidcProvider.EnrichSession, | ||||||
| 		}, | 		oidcRefreshFunc: oidcProvider.RefreshSession, | ||||||
| 		skipScope: adfsSkipScope, |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -68,3 +79,44 @@ func (p *ADFSProvider) GetLoginURL(redirectURI, state, nonce string) string { | ||||||
| 	} | 	} | ||||||
| 	return loginURL.String() | 	return loginURL.String() | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // EnrichSession calls the OIDC ProfileURL to backfill any fields missing
 | ||||||
|  | // from the claims. If Email is missing, falls back to ADFS `upn` claim.
 | ||||||
|  | func (p *ADFSProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | ||||||
|  | 	err := p.oidcEnrichFunc(ctx, s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if s.Email == "" { | ||||||
|  | 		return p.fallbackUPN(ctx, s) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // RefreshSession refreshes via the OIDC implementation. If email is missing,
 | ||||||
|  | // falls back to ADFS `upn` claim.
 | ||||||
|  | func (p *ADFSProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 	refreshed, err := p.oidcRefreshFunc(ctx, s) | ||||||
|  | 	if err != nil || s.Email != "" { | ||||||
|  | 		return refreshed, err | ||||||
|  | 	} | ||||||
|  | 	err = p.fallbackUPN(ctx, s) | ||||||
|  | 	return refreshed, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *ADFSProvider) fallbackUPN(ctx context.Context, s *sessions.SessionState) error { | ||||||
|  | 	idToken, err := p.Verifier.Verify(ctx, s.IDToken) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	claims, err := p.getClaims(idToken) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("couldn't extract claims from id_token (%v)", err) | ||||||
|  | 	} | ||||||
|  | 	upn := claims.raw[adfsUPNClaim] | ||||||
|  | 	if upn != nil { | ||||||
|  | 		s.Email = fmt.Sprint(upn) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -2,6 +2,8 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"crypto/rsa" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | @ -9,6 +11,7 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-oidc/v3/oidc" | 	"github.com/coreos/go-oidc/v3/oidc" | ||||||
|  | 	"github.com/dgrijalva/jwt-go" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | @ -25,8 +28,18 @@ func (fakeADFSJwks) VerifySignature(_ context.Context, jwt string) (payload []by | ||||||
| 	return decodeString, nil | 	return decodeString, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func testADFSProvider(hostname string) *ADFSProvider { | type adfsClaims struct { | ||||||
|  | 	UPN string `json:"upn,omitempty"` | ||||||
|  | 	idTokenClaims | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
|  | func newSignedTestADFSToken(tokenClaims adfsClaims) (string, error) { | ||||||
|  | 	key, _ := rsa.GenerateKey(rand.Reader, 2048) | ||||||
|  | 	standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) | ||||||
|  | 	return standardClaims.SignedString(key) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func testADFSProvider(hostname string) *ADFSProvider { | ||||||
| 	o := oidc.NewVerifier( | 	o := oidc.NewVerifier( | ||||||
| 		"https://issuer.example.com", | 		"https://issuer.example.com", | ||||||
| 		fakeADFSJwks{}, | 		fakeADFSJwks{}, | ||||||
|  | @ -41,6 +54,7 @@ func testADFSProvider(hostname string) *ADFSProvider { | ||||||
| 		ValidateURL:  &url.URL{}, | 		ValidateURL:  &url.URL{}, | ||||||
| 		Scope:        "", | 		Scope:        "", | ||||||
| 		Verifier:     o, | 		Verifier:     o, | ||||||
|  | 		EmailClaim:   OIDCEmailClaim, | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if hostname != "" { | 	if hostname != "" { | ||||||
|  | @ -54,7 +68,6 @@ func testADFSProvider(hostname string) *ADFSProvider { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func testADFSBackend() *httptest.Server { | func testADFSBackend() *httptest.Server { | ||||||
| 
 |  | ||||||
| 	authResponse := ` | 	authResponse := ` | ||||||
| 		{ | 		{ | ||||||
| 			"access_token": "my_access_token", | 			"access_token": "my_access_token", | ||||||
|  | @ -129,7 +142,6 @@ var _ = Describe("ADFS Provider Tests", func() { | ||||||
| 
 | 
 | ||||||
| 	Context("with valid token", func() { | 	Context("with valid token", func() { | ||||||
| 		It("should not throw an error", func() { | 		It("should not throw an error", func() { | ||||||
| 			p.EmailClaim = "email" |  | ||||||
| 			rawIDToken, _ := newSignedTestIDToken(defaultIDToken) | 			rawIDToken, _ := newSignedTestIDToken(defaultIDToken) | ||||||
| 			idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) | 			idToken, err := p.Verifier.Verify(context.Background(), rawIDToken) | ||||||
| 			Expect(err).To(BeNil()) | 			Expect(err).To(BeNil()) | ||||||
|  | @ -202,4 +214,68 @@ var _ = Describe("ADFS Provider Tests", func() { | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UPN Fallback", func() { | ||||||
|  | 		var idToken string | ||||||
|  | 		var session *sessions.SessionState | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			var err error | ||||||
|  | 			idToken, err = newSignedTestADFSToken(adfsClaims{ | ||||||
|  | 				UPN:           "upn@company.com", | ||||||
|  | 				idTokenClaims: minimalIDToken, | ||||||
|  | 			}) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			session = &sessions.SessionState{ | ||||||
|  | 				IDToken: idToken, | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Describe("EnrichSession", func() { | ||||||
|  | 			It("uses email claim if present", func() { | ||||||
|  | 				p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { | ||||||
|  | 					s.Email = "person@company.com" | ||||||
|  | 					return nil | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				err := p.EnrichSession(context.Background(), session) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(session.Email).To(Equal("person@company.com")) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("falls back to UPN claim if Email is missing", func() { | ||||||
|  | 				p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error { | ||||||
|  | 					return nil | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				err := p.EnrichSession(context.Background(), session) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(session.Email).To(Equal("upn@company.com")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Describe("RefreshSession", func() { | ||||||
|  | 			It("uses email claim if present", func() { | ||||||
|  | 				p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 					s.Email = "person@company.com" | ||||||
|  | 					return true, nil | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				_, err := p.RefreshSession(context.Background(), session) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(session.Email).To(Equal("person@company.com")) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("falls back to UPN claim if Email is missing", func() { | ||||||
|  | 				p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) { | ||||||
|  | 					return true, nil | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				_, err := p.RefreshSession(context.Background(), session) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(session.Email).To(Equal("upn@company.com")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
| }) | }) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue