refactor: extract coerceClaim logic into util

Signed-off-by: afsu <suaf2020@163.com>
This commit is contained in:
afsu 2025-04-29 16:22:57 +08:00
parent 230de6253a
commit 6b3f1c60d0
6 changed files with 155 additions and 139 deletions

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
"github.com/pierrec/lz4/v4"
"github.com/vmihailenco/msgpack/v5"
)
@ -165,23 +166,9 @@ func (s *SessionState) GetClaim(claim string) []string {
func (s *SessionState) getAdditionalClaim(claim string) []string {
if value, ok := s.AdditionalClaims[claim]; ok {
switch v := value.(type) {
case string:
return []string{v}
case []string:
return v
case []interface{}:
result := make([]string, len(v))
for i, item := range v {
if str, ok := item.(string); ok {
result[i] = str
} else {
result[i] = fmt.Sprintf("%v", item)
}
}
var result []string
if err := util.CoerceClaim(value, &result); err == nil {
return result
default:
return []string{fmt.Sprintf("%v", value)}
}
}
return []string{}

View File

@ -236,8 +236,7 @@ func TestEncodeAndDecodeSessionState(t *testing.T) {
AdditionalClaims: map[string]interface{}{
"custom_claim_1": "value1",
"custom_claim_2": true,
"custom_claim_3": int8(1),
"custom_claim_4": []interface{}{"item1", "item2"},
"custom_claim_3": []interface{}{"item1", "item2"},
},
},
}
@ -325,8 +324,7 @@ func TestGetClaim(t *testing.T) {
AdditionalClaims: map[string]interface{}{
"custom_claim_1": "value1",
"custom_claim_2": true,
"custom_claim_3": 1,
"custom_claim_4": []string{"item1", "item2"},
"custom_claim_3": []string{"item1", "item2"},
},
}
@ -345,8 +343,7 @@ func TestGetClaim(t *testing.T) {
{"preferred_username", []string{"preferred_user"}},
{"custom_claim_1", []string{"value1"}},
{"custom_claim_2", []string{"true"}},
{"custom_claim_3", []string{"1"}},
{"custom_claim_4", []string{"item1", "item2"}},
{"custom_claim_3", []string{"[\"item1\",\"item2\"]"}},
}
for _, tt := range tests {

View File

@ -3,7 +3,6 @@ package util
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"mime"
"net/http"
@ -12,7 +11,7 @@ import (
"github.com/bitly/go-simplejson"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"github.com/spf13/cast"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
)
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
@ -132,7 +131,7 @@ func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, erro
if !exists {
return false, nil
}
if err := coerceClaim(value, dst); err != nil {
if err := util.CoerceClaim(value, dst); err != nil {
return false, fmt.Errorf("could no coerce claim: %v", err)
}
@ -163,66 +162,3 @@ func getClaimFrom(claim string, src *simplejson.Json) interface{} {
claimParts := strings.Split(claim, ".")
return src.GetPath(claimParts...).Interface()
}
// coerceClaim tries to convert the value into the destination interface type.
// If it can convert the value, it will then store the value in the destination
// interface.
func coerceClaim(value, dst interface{}) error {
switch d := dst.(type) {
case *string:
str, err := toString(value)
if err != nil {
return fmt.Errorf("could not convert value to string: %v", err)
}
*d = str
case *[]string:
strSlice, err := toStringSlice(value)
if err != nil {
return fmt.Errorf("could not convert value to string slice: %v", err)
}
*d = strSlice
case *bool:
*d = cast.ToBool(value)
default:
return fmt.Errorf("unknown type for destination: %T", dst)
}
return nil
}
// toStringSlice converts an interface (either a slice or single value) into
// a slice of strings.
func toStringSlice(value interface{}) ([]string, error) {
var sliceValues []interface{}
switch v := value.(type) {
case []interface{}:
sliceValues = v
case interface{}:
sliceValues = []interface{}{v}
default:
sliceValues = cast.ToSlice(value)
}
out := []string{}
for _, v := range sliceValues {
str, err := toString(v)
if err != nil {
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
}
out = append(out, str)
}
return out, nil
}
// toString coerces a value into a string.
// If it is non-string, marshal it into JSON.
func toString(value interface{}) (string, error) {
if str, err := cast.ToStringE(value); err == nil {
return str, nil
}
jsonStr, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(jsonStr), nil
}

View File

@ -451,53 +451,6 @@ var _ = Describe("Claim Extractor Suite", func() {
}),
)
type coerceClaimTableInput struct {
value interface{}
dst interface{}
expectedDst interface{}
expectedError error
}
DescribeTable("coerceClaim",
func(in coerceClaimTableInput) {
err := coerceClaim(in.value, in.dst)
if in.expectedError != nil {
Expect(err).To(MatchError(in.expectedError))
return
}
Expect(err).ToNot(HaveOccurred())
Expect(in.dst).To(Equal(in.expectedDst))
},
Entry("coerces a string to a string", coerceClaimTableInput{
value: "some_string",
dst: stringPointer(""),
expectedDst: stringPointer("some_string"),
}),
Entry("coerces a slice to a string slice", coerceClaimTableInput{
value: []interface{}{"a", "b"},
dst: stringSlicePointer([]string{}),
expectedDst: stringSlicePointer([]string{"a", "b"}),
}),
Entry("coerces a bool to a bool", coerceClaimTableInput{
value: true,
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a string to a bool", coerceClaimTableInput{
value: "true",
dst: boolPointer(false),
expectedDst: boolPointer(true),
}),
Entry("coerces a map to a string", coerceClaimTableInput{
value: map[string]interface{}{
"foo": []interface{}{"bar", "baz"},
},
dst: stringPointer(""),
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
}),
)
It("should extract claims from a JWT response", func() {
jwtResponsePayload := `{
"user": "jwtUser",
@ -605,10 +558,6 @@ func stringSlicePointer(in []string) *[]string {
return &in
}
func boolPointer(in bool) *bool {
return &in
}
// ******************************
// Different profile URL handlers
// ******************************

View File

@ -5,6 +5,7 @@ import (
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"math/big"
"net"
@ -12,6 +13,8 @@ import (
"os"
"strings"
"time"
"github.com/spf13/cast"
)
func GetCertPool(paths []string, useSystemPool bool) (*x509.CertPool, error) {
@ -191,3 +194,66 @@ func RemoveDuplicateStr(strSlice []string) []string {
}
return list
}
// CoerceClaim tries to convert the value into the destination interface type.
// If it can convert the value, it will then store the value in the destination
// interface.
func CoerceClaim(value, dst interface{}) error {
switch d := dst.(type) {
case *string:
str, err := toString(value)
if err != nil {
return fmt.Errorf("could not convert value to string: %v", err)
}
*d = str
case *[]string:
strSlice, err := toStringSlice(value)
if err != nil {
return fmt.Errorf("could not convert value to string slice: %v", err)
}
*d = strSlice
case *bool:
*d = cast.ToBool(value)
default:
return fmt.Errorf("unknown type for destination: %T", dst)
}
return nil
}
// toStringSlice converts an interface (either a slice or single value) into
// a slice of strings.
func toStringSlice(value interface{}) ([]string, error) {
var sliceValues []interface{}
switch v := value.(type) {
case []interface{}:
sliceValues = v
case interface{}:
sliceValues = []interface{}{v}
default:
sliceValues = cast.ToSlice(value)
}
out := []string{}
for _, v := range sliceValues {
str, err := toString(v)
if err != nil {
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
}
out = append(out, str)
}
return out, nil
}
// toString coerces a value into a string.
// If it is non-string, marshal it into JSON.
func toString(value interface{}) (string, error) {
if str, err := cast.ToStringE(value); err == nil {
return str, nil
}
jsonStr, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(jsonStr), nil
}

View File

@ -2,8 +2,10 @@ package util
import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"os"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@ -253,3 +255,82 @@ func TestGetCertPool(t *testing.T) {
assert.Error(t, err3)
}
}
func stringPointer(s string) *string {
return &s
}
func stringSlicePointer(s []string) *[]string {
return &s
}
func boolPointer(b bool) *bool {
return &b
}
type coerceClaimTableInput struct {
name string
value interface{}
dst interface{}
expectedDst interface{}
expectedError error
}
func TestCoerceClaim(t *testing.T) {
tests := []coerceClaimTableInput{
{
name: "coerces a string to a string",
value: "some_string",
dst: stringPointer(""),
expectedDst: stringPointer("some_string"),
},
{
name: "coerces a slice to a string slice",
value: []interface{}{"a", "b"},
dst: stringSlicePointer([]string{}),
expectedDst: stringSlicePointer([]string{"a", "b"}),
},
{
name: "coerces a bool to a bool",
value: true,
dst: boolPointer(false),
expectedDst: boolPointer(true),
},
{
name: "coerces a string to a bool",
value: "true",
dst: boolPointer(false),
expectedDst: boolPointer(true),
},
{
name: "coerces a map to a string",
value: map[string]interface{}{
"foo": []interface{}{"bar", "baz"},
},
dst: stringPointer(""),
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CoerceClaim(tt.value, tt.dst)
if tt.expectedError != nil {
if err == nil || err.Error() != tt.expectedError.Error() {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !reflect.DeepEqual(tt.dst, tt.expectedDst) {
gotJSON, _ := json.Marshal(tt.dst)
wantJSON, _ := json.Marshal(tt.expectedDst)
t.Errorf("expected dst to be %+v, got %+v", string(wantJSON), string(gotJSON))
}
})
}
}