diff --git a/docs/docs/configuration/alpha_config.md b/docs/docs/configuration/alpha_config.md index b4a75582..b92b42f1 100644 --- a/docs/docs/configuration/alpha_config.md +++ b/docs/docs/configuration/alpha_config.md @@ -526,6 +526,7 @@ Provider holds all configuration for a single provider | `scope` | _string_ | Scope is the OAuth scope specification | | `allowedGroups` | _[]string_ | AllowedGroups is a list of restrict logins to members of this group | | `code_challenge_method` | _string_ | The code challenge method | +| `additionalClaims` | _[]string_ | Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured. | | `backendLogoutURL` | _string_ | URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session | ### ProviderType diff --git a/oauthproxy.go b/oauthproxy.go index 508084c8..82f265f8 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -721,15 +721,17 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { } userInfo := struct { - User string `json:"user"` - Email string `json:"email"` - Groups []string `json:"groups,omitempty"` - PreferredUsername string `json:"preferredUsername,omitempty"` + User string `json:"user"` + Email string `json:"email"` + Groups []string `json:"groups,omitempty"` + PreferredUsername string `json:"preferredUsername,omitempty"` + AdditionalClaims map[string]interface{} `json:"additionalClaims,omitempty"` }{ User: session.User, Email: session.Email, Groups: session.Groups, PreferredUsername: session.PreferredUsername, + AdditionalClaims: session.AdditionalClaims, } if err := json.NewEncoder(rw).Encode(userInfo); err != nil { diff --git a/oauthproxy_test.go b/oauthproxy_test.go index ccabdbbd..b1411e4c 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1032,6 +1032,20 @@ func TestUserInfoEndpointAccepted(t *testing.T) { }, expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"],\"preferredUsername\":\"john\"}\n", }, + { + name: "With Additional Claim", + session: &sessions.SessionState{ + User: "john.doe", + PreferredUsername: "john", + Email: "john.doe@example.com", + Groups: []string{"example", "groups"}, + AccessToken: "my_access_token", + AdditionalClaims: map[string]interface{}{ + "foo": "bar", + }, + }, + expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"],\"preferredUsername\":\"john\",\"additionalClaims\":{\"foo\":\"bar\"}}\n", + }, } for _, tc := range testCases { diff --git a/pkg/apis/options/providers.go b/pkg/apis/options/providers.go index 94bdb592..f734469e 100644 --- a/pkg/apis/options/providers.go +++ b/pkg/apis/options/providers.go @@ -134,6 +134,9 @@ type Provider struct { // The code challenge method CodeChallengeMethod string `yaml:"code_challenge_method,omitempty"` + // Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured. + AdditionalClaims []string `json:"additionalClaims,omitempty"` + // URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session BackendLogoutURL string `yaml:"backendLogoutURL"` } diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index a1f807ab..fef20aab 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -8,6 +8,7 @@ import ( "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" "github.com/pierrec/lz4/v4" "github.com/vmihailenco/msgpack/v5" ) @@ -28,6 +29,9 @@ type SessionState struct { Groups []string `msgpack:"g,omitempty"` PreferredUsername string `msgpack:"pu,omitempty"` + // Additional claims + AdditionalClaims map[string]interface{} `msgpack:"ac,omitempty"` + // Internal helpers, not serialized Clock func() time.Time `msgpack:"-"` // override for time.Now, for testing Lock Lock `msgpack:"-"` @@ -156,10 +160,20 @@ func (s *SessionState) GetClaim(claim string) []string { case "preferred_username": return []string{s.PreferredUsername} default: - return []string{} + return s.getAdditionalClaim(claim) } } +func (s *SessionState) getAdditionalClaim(claim string) []string { + if value, ok := s.AdditionalClaims[claim]; ok { + var result []string + if err := util.CoerceClaim(value, &result); err == nil { + return result + } + } + return []string{} +} + // CheckNonce compares the Nonce against a potential hash of it func (s *SessionState) CheckNonce(hashed string) bool { return encryption.CheckNonce(s.Nonce, hashed) diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 87b97614..1dc6d3ad 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -222,6 +222,23 @@ func TestEncodeAndDecodeSessionState(t *testing.T) { Nonce: []byte("abcdef1234567890abcdef1234567890"), Groups: []string{"group-a", "group-b"}, }, + "With additional claims": { + Email: "username@example.com", + User: "username", + PreferredUsername: "preferred.username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + ExpiresOn: &expires, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + Nonce: []byte("abcdef1234567890abcdef1234567890"), + Groups: []string{"group-a", "group-b"}, + AdditionalClaims: map[string]interface{}{ + "custom_claim_1": "value1", + "custom_claim_2": true, + "custom_claim_3": []interface{}{"item1", "item2"}, + }, + }, } for _, secretSize := range []int{16, 24, 32} { @@ -289,3 +306,50 @@ func compareSessionStates(t *testing.T, expected *SessionState, actual *SessionS act.ExpiresOn = nil assert.Equal(t, exp, act) } + +func TestGetClaim(t *testing.T) { + createdAt := time.Now() + expiresOn := createdAt.Add(1 * time.Hour) + + ss := &SessionState{ + CreatedAt: &createdAt, + ExpiresOn: &expiresOn, + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + Email: "user@example.com", + User: "user123", + Groups: []string{"group1", "group2"}, + PreferredUsername: "preferred_user", + AdditionalClaims: map[string]interface{}{ + "custom_claim_1": "value1", + "custom_claim_2": true, + "custom_claim_3": []string{"item1", "item2"}, + }, + } + + tests := []struct { + claim string + want []string + }{ + {"access_token", []string{"AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7"}}, + {"id_token", []string{"IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7"}}, + {"refresh_token", []string{"RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7"}}, + {"created_at", []string{createdAt.String()}}, + {"expires_on", []string{expiresOn.String()}}, + {"email", []string{"user@example.com"}}, + {"user", []string{"user123"}}, + {"groups", []string{"group1", "group2"}}, + {"preferred_username", []string{"preferred_user"}}, + {"custom_claim_1", []string{"value1"}}, + {"custom_claim_2", []string{"true"}}, + {"custom_claim_3", []string{"[\"item1\",\"item2\"]"}}, + } + + for _, tt := range tests { + t.Run(tt.claim, func(t *testing.T) { + gs := NewWithT(t) + gs.Expect(ss.GetClaim(tt.claim)).To(Equal(tt.want)) + }) + } +} diff --git a/pkg/providers/util/claim_extractor.go b/pkg/providers/util/claim_extractor.go index 9ab7a8c8..913d17ab 100644 --- a/pkg/providers/util/claim_extractor.go +++ b/pkg/providers/util/claim_extractor.go @@ -3,7 +3,6 @@ package util import ( "context" "encoding/base64" - "encoding/json" "fmt" "mime" "net/http" @@ -12,7 +11,7 @@ import ( "github.com/bitly/go-simplejson" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" - "github.com/spf13/cast" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" ) // ClaimExtractor is used to extract claim values from an ID Token, or, if not @@ -132,7 +131,7 @@ func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, erro if !exists { return false, nil } - if err := coerceClaim(value, dst); err != nil { + if err := util.CoerceClaim(value, dst); err != nil { return false, fmt.Errorf("could no coerce claim: %v", err) } @@ -163,66 +162,3 @@ func getClaimFrom(claim string, src *simplejson.Json) interface{} { claimParts := strings.Split(claim, ".") return src.GetPath(claimParts...).Interface() } - -// coerceClaim tries to convert the value into the destination interface type. -// If it can convert the value, it will then store the value in the destination -// interface. -func coerceClaim(value, dst interface{}) error { - switch d := dst.(type) { - case *string: - str, err := toString(value) - if err != nil { - return fmt.Errorf("could not convert value to string: %v", err) - } - *d = str - case *[]string: - strSlice, err := toStringSlice(value) - if err != nil { - return fmt.Errorf("could not convert value to string slice: %v", err) - } - *d = strSlice - case *bool: - *d = cast.ToBool(value) - default: - return fmt.Errorf("unknown type for destination: %T", dst) - } - return nil -} - -// toStringSlice converts an interface (either a slice or single value) into -// a slice of strings. -func toStringSlice(value interface{}) ([]string, error) { - var sliceValues []interface{} - switch v := value.(type) { - case []interface{}: - sliceValues = v - case interface{}: - sliceValues = []interface{}{v} - default: - sliceValues = cast.ToSlice(value) - } - - out := []string{} - for _, v := range sliceValues { - str, err := toString(v) - if err != nil { - return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err) - } - out = append(out, str) - } - return out, nil -} - -// toString coerces a value into a string. -// If it is non-string, marshal it into JSON. -func toString(value interface{}) (string, error) { - if str, err := cast.ToStringE(value); err == nil { - return str, nil - } - - jsonStr, err := json.Marshal(value) - if err != nil { - return "", err - } - return string(jsonStr), nil -} diff --git a/pkg/providers/util/claim_extractor_test.go b/pkg/providers/util/claim_extractor_test.go index 4ce4606f..b8dbd34a 100644 --- a/pkg/providers/util/claim_extractor_test.go +++ b/pkg/providers/util/claim_extractor_test.go @@ -451,53 +451,6 @@ var _ = Describe("Claim Extractor Suite", func() { }), ) - type coerceClaimTableInput struct { - value interface{} - dst interface{} - expectedDst interface{} - expectedError error - } - - DescribeTable("coerceClaim", - func(in coerceClaimTableInput) { - err := coerceClaim(in.value, in.dst) - if in.expectedError != nil { - Expect(err).To(MatchError(in.expectedError)) - return - } - - Expect(err).ToNot(HaveOccurred()) - Expect(in.dst).To(Equal(in.expectedDst)) - }, - Entry("coerces a string to a string", coerceClaimTableInput{ - value: "some_string", - dst: stringPointer(""), - expectedDst: stringPointer("some_string"), - }), - Entry("coerces a slice to a string slice", coerceClaimTableInput{ - value: []interface{}{"a", "b"}, - dst: stringSlicePointer([]string{}), - expectedDst: stringSlicePointer([]string{"a", "b"}), - }), - Entry("coerces a bool to a bool", coerceClaimTableInput{ - value: true, - dst: boolPointer(false), - expectedDst: boolPointer(true), - }), - Entry("coerces a string to a bool", coerceClaimTableInput{ - value: "true", - dst: boolPointer(false), - expectedDst: boolPointer(true), - }), - Entry("coerces a map to a string", coerceClaimTableInput{ - value: map[string]interface{}{ - "foo": []interface{}{"bar", "baz"}, - }, - dst: stringPointer(""), - expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), - }), - ) - It("should extract claims from a JWT response", func() { jwtResponsePayload := `{ "user": "jwtUser", @@ -605,10 +558,6 @@ func stringSlicePointer(in []string) *[]string { return &in } -func boolPointer(in bool) *bool { - return &in -} - // ****************************** // Different profile URL handlers // ****************************** diff --git a/pkg/util/util.go b/pkg/util/util.go index 0f3d70ad..905af0ef 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" + "encoding/json" "fmt" "math/big" "net" @@ -12,6 +13,8 @@ import ( "os" "strings" "time" + + "github.com/spf13/cast" ) func GetCertPool(paths []string, useSystemPool bool) (*x509.CertPool, error) { @@ -191,3 +194,66 @@ func RemoveDuplicateStr(strSlice []string) []string { } return list } + +// CoerceClaim tries to convert the value into the destination interface type. +// If it can convert the value, it will then store the value in the destination +// interface. +func CoerceClaim(value, dst interface{}) error { + switch d := dst.(type) { + case *string: + str, err := toString(value) + if err != nil { + return fmt.Errorf("could not convert value to string: %v", err) + } + *d = str + case *[]string: + strSlice, err := toStringSlice(value) + if err != nil { + return fmt.Errorf("could not convert value to string slice: %v", err) + } + *d = strSlice + case *bool: + *d = cast.ToBool(value) + default: + return fmt.Errorf("unknown type for destination: %T", dst) + } + return nil +} + +// toStringSlice converts an interface (either a slice or single value) into +// a slice of strings. +func toStringSlice(value interface{}) ([]string, error) { + var sliceValues []interface{} + switch v := value.(type) { + case []interface{}: + sliceValues = v + case interface{}: + sliceValues = []interface{}{v} + default: + sliceValues = cast.ToSlice(value) + } + + out := []string{} + for _, v := range sliceValues { + str, err := toString(v) + if err != nil { + return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err) + } + out = append(out, str) + } + return out, nil +} + +// toString coerces a value into a string. +// If it is non-string, marshal it into JSON. +func toString(value interface{}) (string, error) { + if str, err := cast.ToStringE(value); err == nil { + return str, nil + } + + jsonStr, err := json.Marshal(value) + if err != nil { + return "", err + } + return string(jsonStr), nil +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 167c3e59..21afc0a5 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -2,8 +2,10 @@ package util import ( "crypto/x509" + "encoding/json" "encoding/pem" "os" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -253,3 +255,82 @@ func TestGetCertPool(t *testing.T) { assert.Error(t, err3) } } + +func stringPointer(s string) *string { + return &s +} + +func stringSlicePointer(s []string) *[]string { + return &s +} + +func boolPointer(b bool) *bool { + return &b +} + +type coerceClaimTableInput struct { + name string + value interface{} + dst interface{} + expectedDst interface{} + expectedError error +} + +func TestCoerceClaim(t *testing.T) { + tests := []coerceClaimTableInput{ + { + name: "coerces a string to a string", + value: "some_string", + dst: stringPointer(""), + expectedDst: stringPointer("some_string"), + }, + { + name: "coerces a slice to a string slice", + value: []interface{}{"a", "b"}, + dst: stringSlicePointer([]string{}), + expectedDst: stringSlicePointer([]string{"a", "b"}), + }, + { + name: "coerces a bool to a bool", + value: true, + dst: boolPointer(false), + expectedDst: boolPointer(true), + }, + { + name: "coerces a string to a bool", + value: "true", + dst: boolPointer(false), + expectedDst: boolPointer(true), + }, + { + name: "coerces a map to a string", + value: map[string]interface{}{ + "foo": []interface{}{"bar", "baz"}, + }, + dst: stringPointer(""), + expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CoerceClaim(tt.value, tt.dst) + if tt.expectedError != nil { + if err == nil || err.Error() != tt.expectedError.Error() { + t.Errorf("expected error %v, got %v", tt.expectedError, err) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(tt.dst, tt.expectedDst) { + gotJSON, _ := json.Marshal(tt.dst) + wantJSON, _ := json.Marshal(tt.expectedDst) + t.Errorf("expected dst to be %+v, got %+v", string(wantJSON), string(gotJSON)) + } + }) + } +} diff --git a/providers/provider_data.go b/providers/provider_data.go index 95de5c50..35f95487 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -50,6 +50,7 @@ type ProviderData struct { EmailClaim string GroupsClaim string Verifier internaloidc.IDTokenVerifier + AdditionalClaims []string `json:"additionalClaims,omitempty"` SkipClaimsFromProfileURL bool // Universal Group authorization data structure @@ -268,6 +269,11 @@ func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (* } } + // Extract additional claims + if p.AdditionalClaims != nil { + p.extractAdditionalClaims(extractor, ss) + } + // `email_verified` must be present and explicitly set to `false` to be // considered unverified. verifyEmail := (p.EmailClaim == options.OIDCEmailClaim) && !p.AllowUnverifiedEmail @@ -301,6 +307,17 @@ func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.C return extractor, nil } +func (p *ProviderData) extractAdditionalClaims(extractor util.ClaimExtractor, ss *sessions.SessionState) { + if ss.AdditionalClaims == nil { + ss.AdditionalClaims = make(map[string]interface{}) + } + for _, claim := range p.AdditionalClaims { + if value, exists, err := extractor.GetClaim(claim); err == nil && exists { + ss.AdditionalClaims[claim] = value + } + } +} + // checkNonce compares the session's nonce with the IDToken's nonce claim func (p *ProviderData) checkNonce(s *sessions.SessionState) error { extractor, err := p.getClaimExtractor(s.IDToken, "") diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index 044a77b1..9801d20c 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -237,6 +237,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { ExpectedError error ExpectedSession *sessions.SessionState ExpectProfileURLCalled bool + AdditionalClaims []string }{ "Standard": { IDToken: defaultIDToken, @@ -417,6 +418,27 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { SkipClaimsFromProfileURL: true, ExpectedSession: &sessions.SessionState{}, }, + "Additional claims": { + IDToken: defaultIDToken, + AdditionalClaims: []string{"phone_number", "picture"}, + ExpectedSession: &sessions.SessionState{ + PreferredUsername: "Jane Dobbs", + AdditionalClaims: map[string]interface{}{ + "phone_number": "+4798765432", + "picture": "http://mugbook.com/janed/me.jpg", + }, + }, + }, + "Additional claims with missing claim": { + IDToken: defaultIDToken, + AdditionalClaims: []string{"phone_number", "picture1"}, + ExpectedSession: &sessions.SessionState{ + PreferredUsername: "Jane Dobbs", + AdditionalClaims: map[string]interface{}{ + "phone_number": "+4798765432", + }, + }, + }, } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { @@ -453,6 +475,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { provider.EmailClaim = tc.EmailClaim provider.GroupsClaim = tc.GroupsClaim provider.SkipClaimsFromProfileURL = tc.SkipClaimsFromProfileURL + provider.AdditionalClaims = tc.AdditionalClaims rawIDToken, err := newSignedTestIDToken(tc.IDToken) g.Expect(err).ToNot(HaveOccurred()) diff --git a/providers/providers.go b/providers/providers.go index 6af51ecf..af8dd4e4 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -84,6 +84,8 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, ClientSecret: providerConfig.ClientSecret, ClientSecretFile: providerConfig.ClientSecretFile, AuthRequestResponseMode: providerConfig.AuthRequestResponseMode, + // Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured. + AdditionalClaims: providerConfig.AdditionalClaims, } needsVerifier, err := providerRequiresOIDCProviderVerifier(providerConfig.Type)