From 4c2bf5a2fe8cd751f7f93234a2c880c249a341fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nils=20Gustav=20Str=C3=A5b=C3=B8?= <65334626+nilsgstrabo@users.noreply.github.com> Date: Sat, 20 Jan 2024 20:51:42 +0100 Subject: [PATCH] Feature/add option to skip loading claims from profile url (#2329) * add new flag skip-claims-from-profile-url * skip passing profile URL if SkipClaimsFromProfileURL * docs for --skip-claims-from-profile-url flag * update flag comment * update docs * update CHANGELOG.md * Update providers/provider_data.go Co-authored-by: Jan Larwig * Add tests for SkipClaimsFromProfileURL * simplify tests for SkipClaimsFromProfileURL * generate alpha_config.md --------- Co-authored-by: Jan Larwig --- CHANGELOG.md | 1 + docs/docs/configuration/alpha_config.md | 1 + docs/docs/configuration/overview.md | 1 + pkg/apis/options/legacy_options.go | 31 ++++++++------- pkg/apis/options/providers.go | 3 ++ providers/provider_data.go | 18 ++++++--- providers/provider_data_test.go | 50 +++++++++++++++++++++---- providers/providers.go | 1 + 8 files changed, 78 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 634dcdfd..9d04ebf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ - [#1866](https://github.com/oauth2-proxy/oauth2-proxy/pull/1866) Add support for unix socker as upstream (@babs) - [#1949](https://github.com/oauth2-proxy/oauth2-proxy/pull/1949) Allow cookie names with dots in redis sessions (@miguelborges99) - [#2297](https://github.com/oauth2-proxy/oauth2-proxy/pull/2297) Add nightly build and push (@tuunit) +- [#2329](https://github.com/oauth2-proxy/oauth2-proxy/pull/2329) Add an option to skip request to profile URL for resolving missing claims in id_token (@nilsgstrabo) - [#2299](https://github.com/oauth2-proxy/oauth2-proxy/pull/2299) bugfix: OIDCConfig based providers are not respecting flags and configs (@tuunit) - [#2343](https://github.com/oauth2-proxy/oauth2-proxy/pull/2343) chore: Added checksums for .tar.gz (@kvanzuijlen) - [#2248](https://github.com/oauth2-proxy/oauth2-proxy/pull/2248) Added support for semicolons in query strings. (@timwsuqld) diff --git a/docs/docs/configuration/alpha_config.md b/docs/docs/configuration/alpha_config.md index 7143ebb5..9c1ccdd9 100644 --- a/docs/docs/configuration/alpha_config.md +++ b/docs/docs/configuration/alpha_config.md @@ -434,6 +434,7 @@ Provider holds all configuration for a single provider | `loginURLParameters` | _[[]LoginURLParameter](#loginurlparameter)_ | LoginURLParameters defines the parameters that can be passed from the start URL to the IdP login URL | | `redeemURL` | _string_ | RedeemURL is the token redemption endpoint | | `profileURL` | _string_ | ProfileURL is the profile access endpoint | +| `skipClaimsFromProfileURL` | _bool_ | SkipClaimsFromProfileURL allows to skip request to Profile URL for resolving claims not present in id_token
default set to 'false' | | `resource` | _string_ | ProtectedResource is the resource that is protected (Azure AD and ADFS only) | | `validateURL` | _string_ | ValidateURL is the access token validation endpoint | | `scope` | _string_ | Scope is the OAuth scope specification | diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index aaa3dfd9..5fe9350b 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -144,6 +144,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--pass-host-header` | bool | pass the request Host Header to upstream | true | | `--pass-user-headers` | bool | pass X-Forwarded-User, X-Forwarded-Groups, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true | | `--profile-url` | string | Profile access endpoint | | +| `--skip-claims-from-profile-url` | bool | skip request to Profile URL for resolving claims not present in id_token | false | | `--prompt` | string | [OIDC prompt](https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest); if present, `approval-prompt` is ignored | `""` | | `--provider` | string | OAuth provider | google | | `--provider-ca-file` | string \| list | Paths to CA certificates that should be used when connecting to the provider. If not specified, the default Go trust sources are used instead. | diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index f5747f02..6f032365 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -523,6 +523,7 @@ type LegacyProvider struct { LoginURL string `flag:"login-url" cfg:"login_url"` RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` ProfileURL string `flag:"profile-url" cfg:"profile_url"` + SkipClaimsFromProfileURL bool `flag:"skip-claims-from-profile-url" cfg:"skip_claims_from_profile_url"` ProtectedResource string `flag:"resource" cfg:"resource"` ValidateURL string `flag:"validate-url" cfg:"validate_url"` Scope string `flag:"scope" cfg:"scope"` @@ -578,6 +579,7 @@ func legacyProviderFlagSet() *pflag.FlagSet { flagSet.String("login-url", "", "Authentication endpoint") flagSet.String("redeem-url", "", "Token redemption endpoint") flagSet.String("profile-url", "", "Profile access endpoint") + flagSet.Bool("skip-claims-from-profile-url", false, "Skip loading missing claims from profile URL") flagSet.String("resource", "", "The resource that is protected (Azure AD only)") flagSet.String("validate-url", "", "Access token validation endpoint") flagSet.String("scope", "", "OAuth scope specification") @@ -658,20 +660,21 @@ func (l *LegacyProvider) convert() (Providers, error) { providers := Providers{} provider := Provider{ - ClientID: l.ClientID, - ClientSecret: l.ClientSecret, - ClientSecretFile: l.ClientSecretFile, - Type: ProviderType(l.ProviderType), - CAFiles: l.ProviderCAFiles, - UseSystemTrustStore: l.UseSystemTrustStore, - LoginURL: l.LoginURL, - RedeemURL: l.RedeemURL, - ProfileURL: l.ProfileURL, - ProtectedResource: l.ProtectedResource, - ValidateURL: l.ValidateURL, - Scope: l.Scope, - AllowedGroups: l.AllowedGroups, - CodeChallengeMethod: l.CodeChallengeMethod, + ClientID: l.ClientID, + ClientSecret: l.ClientSecret, + ClientSecretFile: l.ClientSecretFile, + Type: ProviderType(l.ProviderType), + CAFiles: l.ProviderCAFiles, + UseSystemTrustStore: l.UseSystemTrustStore, + LoginURL: l.LoginURL, + RedeemURL: l.RedeemURL, + ProfileURL: l.ProfileURL, + SkipClaimsFromProfileURL: l.SkipClaimsFromProfileURL, + ProtectedResource: l.ProtectedResource, + ValidateURL: l.ValidateURL, + Scope: l.Scope, + AllowedGroups: l.AllowedGroups, + CodeChallengeMethod: l.CodeChallengeMethod, } // This part is out of the switch section for all providers that support OIDC diff --git a/pkg/apis/options/providers.go b/pkg/apis/options/providers.go index 8820f345..9772a09e 100644 --- a/pkg/apis/options/providers.go +++ b/pkg/apis/options/providers.go @@ -70,6 +70,9 @@ type Provider struct { RedeemURL string `json:"redeemURL,omitempty"` // ProfileURL is the profile access endpoint ProfileURL string `json:"profileURL,omitempty"` + // SkipClaimsFromProfileURL allows to skip request to Profile URL for resolving claims not present in id_token + // default set to 'false' + SkipClaimsFromProfileURL bool `json:"skipClaimsFromProfileURL,omitempty"` // ProtectedResource is the resource that is protected (Azure AD and ADFS only) ProtectedResource string `json:"resource,omitempty"` // ValidateURL is the access token validation endpoint diff --git a/providers/provider_data.go b/providers/provider_data.go index 0e3f090f..a5fd7c2d 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -43,11 +43,12 @@ type ProviderData struct { SupportedCodeChallengeMethods []string `json:"code_challenge_methods_supported,omitempty"` // Common OIDC options for any OIDC-based providers to consume - AllowUnverifiedEmail bool - UserClaim string - EmailClaim string - GroupsClaim string - Verifier internaloidc.IDTokenVerifier + AllowUnverifiedEmail bool + UserClaim string + EmailClaim string + GroupsClaim string + Verifier internaloidc.IDTokenVerifier + SkipClaimsFromProfileURL bool // Universal Group authorization data structure // any provider can set to consume @@ -283,7 +284,12 @@ func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (* } func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.ClaimExtractor, error) { - extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, p.ProfileURL, p.getAuthorizationHeader(accessToken)) + profileURL := p.ProfileURL + if p.SkipClaimsFromProfileURL { + profileURL = &url.URL{} + } + + extractor, err := util.NewClaimExtractor(context.TODO(), rawIDToken, profileURL, p.getAuthorizationHeader(accessToken)) if err != nil { return nil, fmt.Errorf("could not initialise claim extractor: %v", err) } diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index 838c061b..6bdcf5cc 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -8,6 +8,8 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "net/http/httptest" "net/url" "strings" "testing" @@ -233,13 +235,16 @@ func TestProviderData_verifyIDToken(t *testing.T) { func TestProviderData_buildSessionFromClaims(t *testing.T) { testCases := map[string]struct { - IDToken idTokenClaims - AllowUnverified bool - UserClaim string - EmailClaim string - GroupsClaim string - ExpectedError error - ExpectedSession *sessions.SessionState + IDToken idTokenClaims + AllowUnverified bool + UserClaim string + EmailClaim string + GroupsClaim string + SkipClaimsFromProfileURL bool + SetProfileURL bool + ExpectedError error + ExpectedSession *sessions.SessionState + ExpectProfileURLCalled bool }{ "Standard": { IDToken: defaultIDToken, @@ -408,11 +413,36 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { PreferredUsername: "Jane Dobbs", }, }, + "Request claims from ProfileURL": { + IDToken: minimalIDToken, + SetProfileURL: true, + ExpectProfileURLCalled: true, + ExpectedSession: &sessions.SessionState{}, + }, + "Skip claims request to ProfileURL": { + IDToken: minimalIDToken, + SetProfileURL: true, + SkipClaimsFromProfileURL: true, + ExpectedSession: &sessions.SessionState{}, + }, } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { g := NewWithT(t) + var ( + profileURL *url.URL + profileURLCalled bool + ) + if tc.SetProfileURL { + profileURLSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + profileURLCalled = true + w.Write([]byte("{}")) + })) + defer profileURLSrv.Close() + profileURL, _ = url.Parse(profileURLSrv.URL) + } + verificationOptions := internaloidc.IDTokenVerificationOptions{ AudienceClaims: []string{"aud"}, ClientID: oidcClientID, @@ -423,22 +453,26 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { mockJWKS{}, &oidc.Config{ClientID: oidcClientID}, ), verificationOptions), + ProfileURL: profileURL, + getAuthorizationHeaderFunc: func(s string) http.Header { return http.Header{} }, } provider.AllowUnverifiedEmail = tc.AllowUnverified provider.UserClaim = tc.UserClaim provider.EmailClaim = tc.EmailClaim provider.GroupsClaim = tc.GroupsClaim + provider.SkipClaimsFromProfileURL = tc.SkipClaimsFromProfileURL rawIDToken, err := newSignedTestIDToken(tc.IDToken) g.Expect(err).ToNot(HaveOccurred()) - ss, err := provider.buildSessionFromClaims(rawIDToken, "") + ss, err := provider.buildSessionFromClaims(rawIDToken, "testtoken") if err != nil { g.Expect(err).To(Equal(tc.ExpectedError)) } if ss != nil { g.Expect(ss).To(Equal(tc.ExpectedSession)) } + g.Expect(profileURLCalled).To(Equal(tc.ExpectProfileURLCalled)) }) } } diff --git a/providers/providers.go b/providers/providers.go index 192b86a0..1290642a 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -138,6 +138,7 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, p.AllowUnverifiedEmail = providerConfig.OIDCConfig.InsecureAllowUnverifiedEmail p.EmailClaim = providerConfig.OIDCConfig.EmailClaim p.GroupsClaim = providerConfig.OIDCConfig.GroupsClaim + p.SkipClaimsFromProfileURL = providerConfig.SkipClaimsFromProfileURL // Set PKCE enabled or disabled based on discovery and force options p.CodeChallengeMethod = parseCodeChallengeMethod(providerConfig)