From 6b3f1c60d0bb4f68272b1f814fa0415cf1e5d406 Mon Sep 17 00:00:00 2001 From: afsu Date: Tue, 29 Apr 2025 16:22:57 +0800 Subject: [PATCH] refactor: extract coerceClaim logic into util Signed-off-by: afsu --- pkg/apis/sessions/session_state.go | 19 +---- pkg/apis/sessions/session_state_test.go | 9 +-- pkg/providers/util/claim_extractor.go | 68 +----------------- pkg/providers/util/claim_extractor_test.go | 51 -------------- pkg/util/util.go | 66 ++++++++++++++++++ pkg/util/util_test.go | 81 ++++++++++++++++++++++ 6 files changed, 155 insertions(+), 139 deletions(-) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index dbeb7edb..fef20aab 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -8,6 +8,7 @@ import ( "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" "github.com/pierrec/lz4/v4" "github.com/vmihailenco/msgpack/v5" ) @@ -165,23 +166,9 @@ func (s *SessionState) GetClaim(claim string) []string { func (s *SessionState) getAdditionalClaim(claim string) []string { if value, ok := s.AdditionalClaims[claim]; ok { - switch v := value.(type) { - case string: - return []string{v} - case []string: - return v - case []interface{}: - result := make([]string, len(v)) - for i, item := range v { - if str, ok := item.(string); ok { - result[i] = str - } else { - result[i] = fmt.Sprintf("%v", item) - } - } + var result []string + if err := util.CoerceClaim(value, &result); err == nil { return result - default: - return []string{fmt.Sprintf("%v", value)} } } return []string{} diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 3837a696..1dc6d3ad 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -236,8 +236,7 @@ func TestEncodeAndDecodeSessionState(t *testing.T) { AdditionalClaims: map[string]interface{}{ "custom_claim_1": "value1", "custom_claim_2": true, - "custom_claim_3": int8(1), - "custom_claim_4": []interface{}{"item1", "item2"}, + "custom_claim_3": []interface{}{"item1", "item2"}, }, }, } @@ -325,8 +324,7 @@ func TestGetClaim(t *testing.T) { AdditionalClaims: map[string]interface{}{ "custom_claim_1": "value1", "custom_claim_2": true, - "custom_claim_3": 1, - "custom_claim_4": []string{"item1", "item2"}, + "custom_claim_3": []string{"item1", "item2"}, }, } @@ -345,8 +343,7 @@ func TestGetClaim(t *testing.T) { {"preferred_username", []string{"preferred_user"}}, {"custom_claim_1", []string{"value1"}}, {"custom_claim_2", []string{"true"}}, - {"custom_claim_3", []string{"1"}}, - {"custom_claim_4", []string{"item1", "item2"}}, + {"custom_claim_3", []string{"[\"item1\",\"item2\"]"}}, } for _, tt := range tests { diff --git a/pkg/providers/util/claim_extractor.go b/pkg/providers/util/claim_extractor.go index 9ab7a8c8..913d17ab 100644 --- a/pkg/providers/util/claim_extractor.go +++ b/pkg/providers/util/claim_extractor.go @@ -3,7 +3,6 @@ package util import ( "context" "encoding/base64" - "encoding/json" "fmt" "mime" "net/http" @@ -12,7 +11,7 @@ import ( "github.com/bitly/go-simplejson" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" - "github.com/spf13/cast" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" ) // ClaimExtractor is used to extract claim values from an ID Token, or, if not @@ -132,7 +131,7 @@ func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, erro if !exists { return false, nil } - if err := coerceClaim(value, dst); err != nil { + if err := util.CoerceClaim(value, dst); err != nil { return false, fmt.Errorf("could no coerce claim: %v", err) } @@ -163,66 +162,3 @@ func getClaimFrom(claim string, src *simplejson.Json) interface{} { claimParts := strings.Split(claim, ".") return src.GetPath(claimParts...).Interface() } - -// coerceClaim tries to convert the value into the destination interface type. -// If it can convert the value, it will then store the value in the destination -// interface. -func coerceClaim(value, dst interface{}) error { - switch d := dst.(type) { - case *string: - str, err := toString(value) - if err != nil { - return fmt.Errorf("could not convert value to string: %v", err) - } - *d = str - case *[]string: - strSlice, err := toStringSlice(value) - if err != nil { - return fmt.Errorf("could not convert value to string slice: %v", err) - } - *d = strSlice - case *bool: - *d = cast.ToBool(value) - default: - return fmt.Errorf("unknown type for destination: %T", dst) - } - return nil -} - -// toStringSlice converts an interface (either a slice or single value) into -// a slice of strings. -func toStringSlice(value interface{}) ([]string, error) { - var sliceValues []interface{} - switch v := value.(type) { - case []interface{}: - sliceValues = v - case interface{}: - sliceValues = []interface{}{v} - default: - sliceValues = cast.ToSlice(value) - } - - out := []string{} - for _, v := range sliceValues { - str, err := toString(v) - if err != nil { - return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err) - } - out = append(out, str) - } - return out, nil -} - -// toString coerces a value into a string. -// If it is non-string, marshal it into JSON. -func toString(value interface{}) (string, error) { - if str, err := cast.ToStringE(value); err == nil { - return str, nil - } - - jsonStr, err := json.Marshal(value) - if err != nil { - return "", err - } - return string(jsonStr), nil -} diff --git a/pkg/providers/util/claim_extractor_test.go b/pkg/providers/util/claim_extractor_test.go index 4ce4606f..b8dbd34a 100644 --- a/pkg/providers/util/claim_extractor_test.go +++ b/pkg/providers/util/claim_extractor_test.go @@ -451,53 +451,6 @@ var _ = Describe("Claim Extractor Suite", func() { }), ) - type coerceClaimTableInput struct { - value interface{} - dst interface{} - expectedDst interface{} - expectedError error - } - - DescribeTable("coerceClaim", - func(in coerceClaimTableInput) { - err := coerceClaim(in.value, in.dst) - if in.expectedError != nil { - Expect(err).To(MatchError(in.expectedError)) - return - } - - Expect(err).ToNot(HaveOccurred()) - Expect(in.dst).To(Equal(in.expectedDst)) - }, - Entry("coerces a string to a string", coerceClaimTableInput{ - value: "some_string", - dst: stringPointer(""), - expectedDst: stringPointer("some_string"), - }), - Entry("coerces a slice to a string slice", coerceClaimTableInput{ - value: []interface{}{"a", "b"}, - dst: stringSlicePointer([]string{}), - expectedDst: stringSlicePointer([]string{"a", "b"}), - }), - Entry("coerces a bool to a bool", coerceClaimTableInput{ - value: true, - dst: boolPointer(false), - expectedDst: boolPointer(true), - }), - Entry("coerces a string to a bool", coerceClaimTableInput{ - value: "true", - dst: boolPointer(false), - expectedDst: boolPointer(true), - }), - Entry("coerces a map to a string", coerceClaimTableInput{ - value: map[string]interface{}{ - "foo": []interface{}{"bar", "baz"}, - }, - dst: stringPointer(""), - expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), - }), - ) - It("should extract claims from a JWT response", func() { jwtResponsePayload := `{ "user": "jwtUser", @@ -605,10 +558,6 @@ func stringSlicePointer(in []string) *[]string { return &in } -func boolPointer(in bool) *bool { - return &in -} - // ****************************** // Different profile URL handlers // ****************************** diff --git a/pkg/util/util.go b/pkg/util/util.go index 0f3d70ad..905af0ef 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" + "encoding/json" "fmt" "math/big" "net" @@ -12,6 +13,8 @@ import ( "os" "strings" "time" + + "github.com/spf13/cast" ) func GetCertPool(paths []string, useSystemPool bool) (*x509.CertPool, error) { @@ -191,3 +194,66 @@ func RemoveDuplicateStr(strSlice []string) []string { } return list } + +// CoerceClaim tries to convert the value into the destination interface type. +// If it can convert the value, it will then store the value in the destination +// interface. +func CoerceClaim(value, dst interface{}) error { + switch d := dst.(type) { + case *string: + str, err := toString(value) + if err != nil { + return fmt.Errorf("could not convert value to string: %v", err) + } + *d = str + case *[]string: + strSlice, err := toStringSlice(value) + if err != nil { + return fmt.Errorf("could not convert value to string slice: %v", err) + } + *d = strSlice + case *bool: + *d = cast.ToBool(value) + default: + return fmt.Errorf("unknown type for destination: %T", dst) + } + return nil +} + +// toStringSlice converts an interface (either a slice or single value) into +// a slice of strings. +func toStringSlice(value interface{}) ([]string, error) { + var sliceValues []interface{} + switch v := value.(type) { + case []interface{}: + sliceValues = v + case interface{}: + sliceValues = []interface{}{v} + default: + sliceValues = cast.ToSlice(value) + } + + out := []string{} + for _, v := range sliceValues { + str, err := toString(v) + if err != nil { + return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err) + } + out = append(out, str) + } + return out, nil +} + +// toString coerces a value into a string. +// If it is non-string, marshal it into JSON. +func toString(value interface{}) (string, error) { + if str, err := cast.ToStringE(value); err == nil { + return str, nil + } + + jsonStr, err := json.Marshal(value) + if err != nil { + return "", err + } + return string(jsonStr), nil +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 167c3e59..21afc0a5 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -2,8 +2,10 @@ package util import ( "crypto/x509" + "encoding/json" "encoding/pem" "os" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -253,3 +255,82 @@ func TestGetCertPool(t *testing.T) { assert.Error(t, err3) } } + +func stringPointer(s string) *string { + return &s +} + +func stringSlicePointer(s []string) *[]string { + return &s +} + +func boolPointer(b bool) *bool { + return &b +} + +type coerceClaimTableInput struct { + name string + value interface{} + dst interface{} + expectedDst interface{} + expectedError error +} + +func TestCoerceClaim(t *testing.T) { + tests := []coerceClaimTableInput{ + { + name: "coerces a string to a string", + value: "some_string", + dst: stringPointer(""), + expectedDst: stringPointer("some_string"), + }, + { + name: "coerces a slice to a string slice", + value: []interface{}{"a", "b"}, + dst: stringSlicePointer([]string{}), + expectedDst: stringSlicePointer([]string{"a", "b"}), + }, + { + name: "coerces a bool to a bool", + value: true, + dst: boolPointer(false), + expectedDst: boolPointer(true), + }, + { + name: "coerces a string to a bool", + value: "true", + dst: boolPointer(false), + expectedDst: boolPointer(true), + }, + { + name: "coerces a map to a string", + value: map[string]interface{}{ + "foo": []interface{}{"bar", "baz"}, + }, + dst: stringPointer(""), + expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CoerceClaim(tt.value, tt.dst) + if tt.expectedError != nil { + if err == nil || err.Error() != tt.expectedError.Error() { + t.Errorf("expected error %v, got %v", tt.expectedError, err) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(tt.dst, tt.expectedDst) { + gotJSON, _ := json.Marshal(tt.dst) + wantJSON, _ := json.Marshal(tt.expectedDst) + t.Errorf("expected dst to be %+v, got %+v", string(wantJSON), string(gotJSON)) + } + }) + } +}