diff --git a/CHANGELOG.md b/CHANGELOG.md index cad157f1..0470479c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ - [#3352](https://github.com/oauth2-proxy/oauth2-proxy/pull/3352) fix: backend logout URL call on sign out (#3172)(@vsejpal) - [#3332](https://github.com/oauth2-proxy/oauth2-proxy/pull/3332) ci: distribute windows binary with .exe extension (@igitur) -- [#2685](https://github.com/oauth2-proxy/oauth2-proxy/pull/2685) feat: allow arbitrary claims from the IDToken to be injected into the response header (@vegetablest) +- [#2685](https://github.com/oauth2-proxy/oauth2-proxy/pull/2685) feat: allow arbitrary claims from the IDToken and IdentityProvider UserInfo endpoint to be added to the session state (@vegetablest) # V7.14.3 diff --git a/pkg/apis/options/providers.go b/pkg/apis/options/providers.go index f734469e..55965ed9 100644 --- a/pkg/apis/options/providers.go +++ b/pkg/apis/options/providers.go @@ -135,7 +135,7 @@ type Provider struct { CodeChallengeMethod string `yaml:"code_challenge_method,omitempty"` // Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured. - AdditionalClaims []string `json:"additionalClaims,omitempty"` + AdditionalClaims []string `yaml:"additionalClaims,omitempty"` // URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session BackendLogoutURL string `yaml:"backendLogoutURL"` diff --git a/pkg/providers/util/claim_extractor.go b/pkg/providers/util/claim_extractor.go index 913d17ab..469a4f54 100644 --- a/pkg/providers/util/claim_extractor.go +++ b/pkg/providers/util/claim_extractor.go @@ -18,10 +18,10 @@ import ( // present, from the profile URL. type ClaimExtractor interface { // GetClaim fetches a named claim and returns the value. - GetClaim(claim string) (interface{}, bool, error) + GetClaim(claim string) (any, bool, error) // GetClaimInto fetches a named claim and puts the value into the destination. - GetClaimInto(claim string, dst interface{}) (bool, error) + GetClaimInto(claim string, dst any) (bool, error) } // NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token. @@ -30,12 +30,12 @@ type ClaimExtractor interface { func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) { payload, err := parseJWT(idToken) if err != nil { - return nil, fmt.Errorf("failed to parse ID Token: %v", err) + return nil, fmt.Errorf("failed to parse ID Token: %w", err) } tokenClaims, err := simplejson.NewJson(payload) if err != nil { - return nil, fmt.Errorf("failed to parse ID Token payload: %v", err) + return nil, fmt.Errorf("failed to parse ID Token payload: %w", err) } return &claimExtractor{ @@ -58,7 +58,7 @@ type claimExtractor struct { // GetClaim will return the value claim if it exists. // It will only return an error if the profile URL needs to be fetched due to // the claim not being present in the ID Token. -func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) { +func (c *claimExtractor) GetClaim(claim string) (any, bool, error) { if claim == "" { return nil, false, nil } @@ -123,7 +123,7 @@ func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) { // GetClaimInto loads a claim and places it into the destination interface. // This will attempt to coerce the claim into the specified type. // If it cannot be coerced, an error may be returned. -func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) { +func (c *claimExtractor) GetClaimInto(claim string, dst any) (bool, error) { value, exists, err := c.GetClaim(claim) if err != nil { return false, fmt.Errorf("could not get claim %q: %v", claim, err) @@ -132,7 +132,7 @@ func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, erro return false, nil } if err := util.CoerceClaim(value, dst); err != nil { - return false, fmt.Errorf("could no coerce claim: %v", err) + return false, fmt.Errorf("could not coerce claim: %v", err) } return true, nil @@ -155,7 +155,7 @@ func parseJWT(p string) ([]byte, error) { // getClaimFrom gets a claim from a Json object. // It can accept either a single claim name or a json path. The claim is always evaluated first as a single claim name. // Paths with indexes are not supported. -func getClaimFrom(claim string, src *simplejson.Json) interface{} { +func getClaimFrom(claim string, src *simplejson.Json) any { if value, ok := src.CheckGet(claim); ok { return value.Interface() } diff --git a/pkg/providers/util/claim_extractor_test.go b/pkg/providers/util/claim_extractor_test.go index b8dbd34a..57e2e1dc 100644 --- a/pkg/providers/util/claim_extractor_test.go +++ b/pkg/providers/util/claim_extractor_test.go @@ -76,7 +76,7 @@ var _ = Describe("Claim Extractor Suite", func() { func(in newClaimExtractorTableInput) { _, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil) if in.expectedError != nil { - Expect(err).To(MatchError(in.expectedError)) + Expect(err).To(MatchError(in.expectedError.Error())) } else { Expect(err).ToNot(HaveOccurred()) } @@ -405,7 +405,7 @@ var _ = Describe("Claim Extractor Suite", func() { into: "", expectExists: false, expectedValue: "", - expectedError: errors.New("could no coerce claim: unknown type for destination: string"), + expectedError: errors.New("could not coerce claim: unknown type for destination: string"), }), Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{ testClaimExtractorOpts: testClaimExtractorOpts{ diff --git a/pkg/util/util.go b/pkg/util/util.go index 905af0ef..207316e7 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -198,7 +198,7 @@ func RemoveDuplicateStr(strSlice []string) []string { // CoerceClaim tries to convert the value into the destination interface type. // If it can convert the value, it will then store the value in the destination // interface. -func CoerceClaim(value, dst interface{}) error { +func CoerceClaim(value, dst any) error { switch d := dst.(type) { case *string: str, err := toString(value) @@ -222,15 +222,13 @@ func CoerceClaim(value, dst interface{}) error { // toStringSlice converts an interface (either a slice or single value) into // a slice of strings. -func toStringSlice(value interface{}) ([]string, error) { - var sliceValues []interface{} +func toStringSlice(value any) ([]string, error) { + var sliceValues []any switch v := value.(type) { - case []interface{}: + case []any: sliceValues = v - case interface{}: - sliceValues = []interface{}{v} default: - sliceValues = cast.ToSlice(value) + sliceValues = []any{v} } out := []string{} @@ -246,7 +244,7 @@ func toStringSlice(value interface{}) ([]string, error) { // toString coerces a value into a string. // If it is non-string, marshal it into JSON. -func toString(value interface{}) (string, error) { +func toString(value any) (string, error) { if str, err := cast.ToStringE(value); err == nil { return str, nil } diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 21afc0a5..d2a2eeca 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -8,6 +8,7 @@ import ( "reflect" "testing" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util/ptr" "github.com/stretchr/testify/assert" ) @@ -256,23 +257,11 @@ func TestGetCertPool(t *testing.T) { } } -func stringPointer(s string) *string { - return &s -} - -func stringSlicePointer(s []string) *[]string { - return &s -} - -func boolPointer(b bool) *bool { - return &b -} - type coerceClaimTableInput struct { name string - value interface{} - dst interface{} - expectedDst interface{} + value any + dst any + expectedDst any expectedError error } @@ -281,34 +270,34 @@ func TestCoerceClaim(t *testing.T) { { name: "coerces a string to a string", value: "some_string", - dst: stringPointer(""), - expectedDst: stringPointer("some_string"), + dst: ptr.To(""), + expectedDst: ptr.To("some_string"), }, { name: "coerces a slice to a string slice", - value: []interface{}{"a", "b"}, - dst: stringSlicePointer([]string{}), - expectedDst: stringSlicePointer([]string{"a", "b"}), + value: []any{"a", "b"}, + dst: ptr.To([]string{}), + expectedDst: ptr.To([]string{"a", "b"}), }, { name: "coerces a bool to a bool", value: true, - dst: boolPointer(false), - expectedDst: boolPointer(true), + dst: ptr.To(false), + expectedDst: ptr.To(true), }, { name: "coerces a string to a bool", value: "true", - dst: boolPointer(false), - expectedDst: boolPointer(true), + dst: ptr.To(false), + expectedDst: ptr.To(true), }, { name: "coerces a map to a string", - value: map[string]interface{}{ - "foo": []interface{}{"bar", "baz"}, + value: map[string]any{ + "foo": []any{"bar", "baz"}, }, - dst: stringPointer(""), - expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), + dst: ptr.To(""), + expectedDst: ptr.To("{\"foo\":[\"bar\",\"baz\"]}"), }, } diff --git a/providers/provider_data.go b/providers/provider_data.go index 35f95487..8f9d1e36 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -45,12 +45,26 @@ type ProviderData struct { SupportedCodeChallengeMethods []string `json:"code_challenge_methods_supported,omitempty"` // Common OIDC options for any OIDC-based providers to consume - AllowUnverifiedEmail bool - UserClaim string - EmailClaim string - GroupsClaim string - Verifier internaloidc.IDTokenVerifier - AdditionalClaims []string `json:"additionalClaims,omitempty"` + AllowUnverifiedEmail bool + + // UserClaim is the claim to use for populating the SessionState.User field. Defaults to "sub" if not set. + UserClaim string + + // EmailClaim is the claim to use for populating the SessionState.Email field. + EmailClaim string + + // GroupsClaim is the claim to use for populating the SessionState.Groups field. + // If not set, groups will not be extracted from the ID Token or userinfo response. + GroupsClaim string + + // Verifier is the OIDC ID Token Verifier to be used by any OIDC-based providers to verify ID Tokens returned by the provider. + // It must be set up by the provider implementation and is not expected to be configured directly by users. + Verifier internaloidc.IDTokenVerifier + + // Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured. + AdditionalClaims []string `json:"additionalClaims,omitempty"` + + // SkipClaimsFromProfileURL indicates that claims should not be fetched from the ProfileURL, even if it is set. SkipClaimsFromProfileURL bool // Universal Group authorization data structure @@ -269,7 +283,6 @@ func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (* } } - // Extract additional claims if p.AdditionalClaims != nil { p.extractAdditionalClaims(extractor, ss) } @@ -309,10 +322,15 @@ func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.C func (p *ProviderData) extractAdditionalClaims(extractor util.ClaimExtractor, ss *sessions.SessionState) { if ss.AdditionalClaims == nil { - ss.AdditionalClaims = make(map[string]interface{}) + ss.AdditionalClaims = make(map[string]any) } for _, claim := range p.AdditionalClaims { - if value, exists, err := extractor.GetClaim(claim); err == nil && exists { + value, exists, err := extractor.GetClaim(claim) + if err != nil { + logger.Printf("error extracting additional claim %q: %v", claim, err) + continue + } + if exists { ss.AdditionalClaims[claim] = value } } diff --git a/providers/providers.go b/providers/providers.go index af8dd4e4..85c45ac5 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -84,8 +84,7 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, ClientSecret: providerConfig.ClientSecret, ClientSecretFile: providerConfig.ClientSecretFile, AuthRequestResponseMode: providerConfig.AuthRequestResponseMode, - // Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured. - AdditionalClaims: providerConfig.AdditionalClaims, + AdditionalClaims: providerConfig.AdditionalClaims, } needsVerifier, err := providerRequiresOIDCProviderVerifier(providerConfig.Type)