This commit is contained in:
af su 2026-01-28 10:59:04 +08:00 committed by GitHub
commit c27c1707c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 294 additions and 122 deletions

View File

@ -526,6 +526,7 @@ Provider holds all configuration for a single provider
| `scope` | _string_ | Scope is the OAuth scope specification |
| `allowedGroups` | _[]string_ | AllowedGroups is a list of restrict logins to members of this group |
| `code_challenge_method` | _string_ | The code challenge method |
| `additionalClaims` | _[]string_ | Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured. |
| `backendLogoutURL` | _string_ | URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session |
### ProviderType

View File

@ -721,15 +721,17 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
}
userInfo := struct {
User string `json:"user"`
Email string `json:"email"`
Groups []string `json:"groups,omitempty"`
PreferredUsername string `json:"preferredUsername,omitempty"`
User string `json:"user"`
Email string `json:"email"`
Groups []string `json:"groups,omitempty"`
PreferredUsername string `json:"preferredUsername,omitempty"`
AdditionalClaims map[string]interface{} `json:"additionalClaims,omitempty"`
}{
User: session.User,
Email: session.Email,
Groups: session.Groups,
PreferredUsername: session.PreferredUsername,
AdditionalClaims: session.AdditionalClaims,
}
if err := json.NewEncoder(rw).Encode(userInfo); err != nil {

View File

@ -1032,6 +1032,20 @@ func TestUserInfoEndpointAccepted(t *testing.T) {
},
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"],\"preferredUsername\":\"john\"}\n",
},
{
name: "With Additional Claim",
session: &sessions.SessionState{
User: "john.doe",
PreferredUsername: "john",
Email: "john.doe@example.com",
Groups: []string{"example", "groups"},
AccessToken: "my_access_token",
AdditionalClaims: map[string]interface{}{
"foo": "bar",
},
},
expectedResponse: "{\"user\":\"john.doe\",\"email\":\"john.doe@example.com\",\"groups\":[\"example\",\"groups\"],\"preferredUsername\":\"john\",\"additionalClaims\":{\"foo\":\"bar\"}}\n",
},
}
for _, tc := range testCases {

View File

@ -134,6 +134,9 @@ type Provider struct {
// The code challenge method
CodeChallengeMethod string `yaml:"code_challenge_method,omitempty"`
// Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured.
AdditionalClaims []string `json:"additionalClaims,omitempty"`
// URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session
BackendLogoutURL string `yaml:"backendLogoutURL"`
}

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
"github.com/pierrec/lz4/v4"
"github.com/vmihailenco/msgpack/v5"
)
@ -28,6 +29,9 @@ type SessionState struct {
Groups []string `msgpack:"g,omitempty"`
PreferredUsername string `msgpack:"pu,omitempty"`
// Additional claims
AdditionalClaims map[string]interface{} `msgpack:"ac,omitempty"`
// Internal helpers, not serialized
Clock func() time.Time `msgpack:"-"` // override for time.Now, for testing
Lock Lock `msgpack:"-"`
@ -156,10 +160,20 @@ func (s *SessionState) GetClaim(claim string) []string {
case "preferred_username":
return []string{s.PreferredUsername}
default:
return []string{}
return s.getAdditionalClaim(claim)
}
}
func (s *SessionState) getAdditionalClaim(claim string) []string {
if value, ok := s.AdditionalClaims[claim]; ok {
var result []string
if err := util.CoerceClaim(value, &result); err == nil {
return result
}
}
return []string{}
}
// CheckNonce compares the Nonce against a potential hash of it
func (s *SessionState) CheckNonce(hashed string) bool {
return encryption.CheckNonce(s.Nonce, hashed)

View File

@ -222,6 +222,23 @@ func TestEncodeAndDecodeSessionState(t *testing.T) {
Nonce: []byte("abcdef1234567890abcdef1234567890"),
Groups: []string{"group-a", "group-b"},
},
"With additional claims": {
Email: "username@example.com",
User: "username",
PreferredUsername: "preferred.username",
AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
Nonce: []byte("abcdef1234567890abcdef1234567890"),
Groups: []string{"group-a", "group-b"},
AdditionalClaims: map[string]interface{}{
"custom_claim_1": "value1",
"custom_claim_2": true,
"custom_claim_3": []interface{}{"item1", "item2"},
},
},
}
for _, secretSize := range []int{16, 24, 32} {
@ -289,3 +306,50 @@ func compareSessionStates(t *testing.T, expected *SessionState, actual *SessionS
act.ExpiresOn = nil
assert.Equal(t, exp, act)
}
func TestGetClaim(t *testing.T) {
createdAt := time.Now()
expiresOn := createdAt.Add(1 * time.Hour)
ss := &SessionState{
CreatedAt: &createdAt,
ExpiresOn: &expiresOn,
AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
Email: "user@example.com",
User: "user123",
Groups: []string{"group1", "group2"},
PreferredUsername: "preferred_user",
AdditionalClaims: map[string]interface{}{
"custom_claim_1": "value1",
"custom_claim_2": true,
"custom_claim_3": []string{"item1", "item2"},
},
}
tests := []struct {
claim string
want []string
}{
{"access_token", []string{"AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7"}},
{"id_token", []string{"IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7"}},
{"refresh_token", []string{"RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7"}},
{"created_at", []string{createdAt.String()}},
{"expires_on", []string{expiresOn.String()}},
{"email", []string{"user@example.com"}},
{"user", []string{"user123"}},
{"groups", []string{"group1", "group2"}},
{"preferred_username", []string{"preferred_user"}},
{"custom_claim_1", []string{"value1"}},
{"custom_claim_2", []string{"true"}},
{"custom_claim_3", []string{"[\"item1\",\"item2\"]"}},
}
for _, tt := range tests {
t.Run(tt.claim, func(t *testing.T) {
gs := NewWithT(t)
gs.Expect(ss.GetClaim(tt.claim)).To(Equal(tt.want))
})
}
}

View File

@ -3,7 +3,6 @@ package util
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"mime"
"net/http"
@ -12,7 +11,7 @@ import (
"github.com/bitly/go-simplejson"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"github.com/spf13/cast"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
)
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
@ -132,7 +131,7 @@ func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, erro
if !exists {
return false, nil
}
if err := coerceClaim(value, dst); err != nil {
if err := util.CoerceClaim(value, dst); err != nil {
return false, fmt.Errorf("could no coerce claim: %v", err)
}
@ -163,66 +162,3 @@ func getClaimFrom(claim string, src *simplejson.Json) interface{} {
claimParts := strings.Split(claim, ".")
return src.GetPath(claimParts...).Interface()
}
// coerceClaim tries to convert the value into the destination interface type.
// If it can convert the value, it will then store the value in the destination
// interface.
func coerceClaim(value, dst interface{}) error {
switch d := dst.(type) {
case *string:
str, err := toString(value)
if err != nil {
return fmt.Errorf("could not convert value to string: %v", err)
}
*d = str
case *[]string:
strSlice, err := toStringSlice(value)
if err != nil {
return fmt.Errorf("could not convert value to string slice: %v", err)
}
*d = strSlice
case *bool:
*d = cast.ToBool(value)
default:
return fmt.Errorf("unknown type for destination: %T", dst)
}
return nil
}
// toStringSlice converts an interface (either a slice or single value) into
// a slice of strings.
func toStringSlice(value interface{}) ([]string, error) {
var sliceValues []interface{}
switch v := value.(type) {
case []interface{}:
sliceValues = v
case interface{}:
sliceValues = []interface{}{v}
default:
sliceValues = cast.ToSlice(value)
}
out := []string{}
for _, v := range sliceValues {
str, err := toString(v)
if err != nil {
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
}
out = append(out, str)
}
return out, nil
}
// toString coerces a value into a string.
// If it is non-string, marshal it into JSON.
func toString(value interface{}) (string, error) {
if str, err := cast.ToStringE(value); err == nil {
return str, nil
}
jsonStr, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(jsonStr), nil
}

View File

@ -451,53 +451,6 @@ var _ = Describe("Claim Extractor Suite", func() {
}),
)
type coerceClaimTableInput struct {
value interface{}
dst interface{}
expectedDst interface{}
expectedError error
}
DescribeTable("coerceClaim",
func(in coerceClaimTableInput) {
err := coerceClaim(in.value, in.dst)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
Expect(in.dst).To(Equal(in.expectedDst))
},
Entry("coerces a string to a string", coerceClaimTableInput{
value: "some_string",
dst: stringPointer(""),
expectedDst: stringPointer("some_string"),
}),
Entry("coerces a slice to a string slice", coerceClaimTableInput{
value: []interface{}{"a", "b"},
dst: stringSlicePointer([]string{}),
expectedDst: stringSlicePointer([]string{"a", "b"}),
}),
Entry("coerces a bool to a bool", coerceClaimTableInput{
value: true,
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a string to a bool", coerceClaimTableInput{
value: "true",
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a map to a string", coerceClaimTableInput{
value: map[string]interface{}{
"foo": []interface{}{"bar", "baz"},
},
dst: stringPointer(""),
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
}),
)
It("should extract claims from a JWT response", func() {
jwtResponsePayload := `{
"user": "jwtUser",
@ -605,10 +558,6 @@ func stringSlicePointer(in []string) *[]string {
return &in
}
func boolPointer(in bool) *bool {
return &in
}
// ******************************
// Different profile URL handlers
// ******************************

View File

@ -5,6 +5,7 @@ import (
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"math/big"
"net"
@ -12,6 +13,8 @@ import (
"os"
"strings"
"time"
"github.com/spf13/cast"
)
func GetCertPool(paths []string, useSystemPool bool) (*x509.CertPool, error) {
@ -191,3 +194,66 @@ func RemoveDuplicateStr(strSlice []string) []string {
}
return list
}
// CoerceClaim tries to convert the value into the destination interface type.
// If it can convert the value, it will then store the value in the destination
// interface.
func CoerceClaim(value, dst interface{}) error {
switch d := dst.(type) {
case *string:
str, err := toString(value)
if err != nil {
return fmt.Errorf("could not convert value to string: %v", err)
}
*d = str
case *[]string:
strSlice, err := toStringSlice(value)
if err != nil {
return fmt.Errorf("could not convert value to string slice: %v", err)
}
*d = strSlice
case *bool:
*d = cast.ToBool(value)
default:
return fmt.Errorf("unknown type for destination: %T", dst)
}
return nil
}
// toStringSlice converts an interface (either a slice or single value) into
// a slice of strings.
func toStringSlice(value interface{}) ([]string, error) {
var sliceValues []interface{}
switch v := value.(type) {
case []interface{}:
sliceValues = v
case interface{}:
sliceValues = []interface{}{v}
default:
sliceValues = cast.ToSlice(value)
}
out := []string{}
for _, v := range sliceValues {
str, err := toString(v)
if err != nil {
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
}
out = append(out, str)
}
return out, nil
}
// toString coerces a value into a string.
// If it is non-string, marshal it into JSON.
func toString(value interface{}) (string, error) {
if str, err := cast.ToStringE(value); err == nil {
return str, nil
}
jsonStr, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(jsonStr), nil
}

View File

@ -2,8 +2,10 @@ package util
import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"os"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@ -253,3 +255,82 @@ func TestGetCertPool(t *testing.T) {
assert.Error(t, err3)
}
}
func stringPointer(s string) *string {
return &s
}
func stringSlicePointer(s []string) *[]string {
return &s
}
func boolPointer(b bool) *bool {
return &b
}
type coerceClaimTableInput struct {
name string
value interface{}
dst interface{}
expectedDst interface{}
expectedError error
}
func TestCoerceClaim(t *testing.T) {
tests := []coerceClaimTableInput{
{
name: "coerces a string to a string",
value: "some_string",
dst: stringPointer(""),
expectedDst: stringPointer("some_string"),
},
{
name: "coerces a slice to a string slice",
value: []interface{}{"a", "b"},
dst: stringSlicePointer([]string{}),
expectedDst: stringSlicePointer([]string{"a", "b"}),
},
{
name: "coerces a bool to a bool",
value: true,
dst: boolPointer(false),
expectedDst: boolPointer(true),
},
{
name: "coerces a string to a bool",
value: "true",
dst: boolPointer(false),
expectedDst: boolPointer(true),
},
{
name: "coerces a map to a string",
value: map[string]interface{}{
"foo": []interface{}{"bar", "baz"},
},
dst: stringPointer(""),
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CoerceClaim(tt.value, tt.dst)
if tt.expectedError != nil {
if err == nil || err.Error() != tt.expectedError.Error() {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !reflect.DeepEqual(tt.dst, tt.expectedDst) {
gotJSON, _ := json.Marshal(tt.dst)
wantJSON, _ := json.Marshal(tt.expectedDst)
t.Errorf("expected dst to be %+v, got %+v", string(wantJSON), string(gotJSON))
}
})
}
}

View File

@ -50,6 +50,7 @@ type ProviderData struct {
EmailClaim string
GroupsClaim string
Verifier internaloidc.IDTokenVerifier
AdditionalClaims []string `json:"additionalClaims,omitempty"`
SkipClaimsFromProfileURL bool
// Universal Group authorization data structure
@ -268,6 +269,11 @@ func (p *ProviderData) buildSessionFromClaims(rawIDToken, accessToken string) (*
}
}
// Extract additional claims
if p.AdditionalClaims != nil {
p.extractAdditionalClaims(extractor, ss)
}
// `email_verified` must be present and explicitly set to `false` to be
// considered unverified.
verifyEmail := (p.EmailClaim == options.OIDCEmailClaim) && !p.AllowUnverifiedEmail
@ -301,6 +307,17 @@ func (p *ProviderData) getClaimExtractor(rawIDToken, accessToken string) (util.C
return extractor, nil
}
func (p *ProviderData) extractAdditionalClaims(extractor util.ClaimExtractor, ss *sessions.SessionState) {
if ss.AdditionalClaims == nil {
ss.AdditionalClaims = make(map[string]interface{})
}
for _, claim := range p.AdditionalClaims {
if value, exists, err := extractor.GetClaim(claim); err == nil && exists {
ss.AdditionalClaims[claim] = value
}
}
}
// checkNonce compares the session's nonce with the IDToken's nonce claim
func (p *ProviderData) checkNonce(s *sessions.SessionState) error {
extractor, err := p.getClaimExtractor(s.IDToken, "")

View File

@ -237,6 +237,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
ExpectedError error
ExpectedSession *sessions.SessionState
ExpectProfileURLCalled bool
AdditionalClaims []string
}{
"Standard": {
IDToken: defaultIDToken,
@ -417,6 +418,27 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
SkipClaimsFromProfileURL: true,
ExpectedSession: &sessions.SessionState{},
},
"Additional claims": {
IDToken: defaultIDToken,
AdditionalClaims: []string{"phone_number", "picture"},
ExpectedSession: &sessions.SessionState{
PreferredUsername: "Jane Dobbs",
AdditionalClaims: map[string]interface{}{
"phone_number": "+4798765432",
"picture": "http://mugbook.com/janed/me.jpg",
},
},
},
"Additional claims with missing claim": {
IDToken: defaultIDToken,
AdditionalClaims: []string{"phone_number", "picture1"},
ExpectedSession: &sessions.SessionState{
PreferredUsername: "Jane Dobbs",
AdditionalClaims: map[string]interface{}{
"phone_number": "+4798765432",
},
},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
@ -453,6 +475,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
provider.EmailClaim = tc.EmailClaim
provider.GroupsClaim = tc.GroupsClaim
provider.SkipClaimsFromProfileURL = tc.SkipClaimsFromProfileURL
provider.AdditionalClaims = tc.AdditionalClaims
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())

View File

@ -84,6 +84,8 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData,
ClientSecret: providerConfig.ClientSecret,
ClientSecretFile: providerConfig.ClientSecretFile,
AuthRequestResponseMode: providerConfig.AuthRequestResponseMode,
// Additional claims to be obtained from the upstream IDP, either from the id_token or from the userinfo endpoint if configured.
AdditionalClaims: providerConfig.AdditionalClaims,
}
needsVerifier, err := providerRequiresOIDCProviderVerifier(providerConfig.Type)