refactor: extract coerceClaim logic into util
Signed-off-by: afsu <suaf2020@163.com>
This commit is contained in:
parent
230de6253a
commit
6b3f1c60d0
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
"github.com/pierrec/lz4/v4"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
|
@ -165,23 +166,9 @@ func (s *SessionState) GetClaim(claim string) []string {
|
|||
|
||||
func (s *SessionState) getAdditionalClaim(claim string) []string {
|
||||
if value, ok := s.AdditionalClaims[claim]; ok {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return []string{v}
|
||||
case []string:
|
||||
return v
|
||||
case []interface{}:
|
||||
result := make([]string, len(v))
|
||||
for i, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
result[i] = str
|
||||
} else {
|
||||
result[i] = fmt.Sprintf("%v", item)
|
||||
}
|
||||
}
|
||||
var result []string
|
||||
if err := util.CoerceClaim(value, &result); err == nil {
|
||||
return result
|
||||
default:
|
||||
return []string{fmt.Sprintf("%v", value)}
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
|
|
|
|||
|
|
@ -236,8 +236,7 @@ func TestEncodeAndDecodeSessionState(t *testing.T) {
|
|||
AdditionalClaims: map[string]interface{}{
|
||||
"custom_claim_1": "value1",
|
||||
"custom_claim_2": true,
|
||||
"custom_claim_3": int8(1),
|
||||
"custom_claim_4": []interface{}{"item1", "item2"},
|
||||
"custom_claim_3": []interface{}{"item1", "item2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -325,8 +324,7 @@ func TestGetClaim(t *testing.T) {
|
|||
AdditionalClaims: map[string]interface{}{
|
||||
"custom_claim_1": "value1",
|
||||
"custom_claim_2": true,
|
||||
"custom_claim_3": 1,
|
||||
"custom_claim_4": []string{"item1", "item2"},
|
||||
"custom_claim_3": []string{"item1", "item2"},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -345,8 +343,7 @@ func TestGetClaim(t *testing.T) {
|
|||
{"preferred_username", []string{"preferred_user"}},
|
||||
{"custom_claim_1", []string{"value1"}},
|
||||
{"custom_claim_2", []string{"true"}},
|
||||
{"custom_claim_3", []string{"1"}},
|
||||
{"custom_claim_4", []string{"item1", "item2"}},
|
||||
{"custom_claim_3", []string{"[\"item1\",\"item2\"]"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package util
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
|
|
@ -12,7 +11,7 @@ import (
|
|||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
)
|
||||
|
||||
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
|
||||
|
|
@ -132,7 +131,7 @@ func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, erro
|
|||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
if err := coerceClaim(value, dst); err != nil {
|
||||
if err := util.CoerceClaim(value, dst); err != nil {
|
||||
return false, fmt.Errorf("could no coerce claim: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -163,66 +162,3 @@ func getClaimFrom(claim string, src *simplejson.Json) interface{} {
|
|||
claimParts := strings.Split(claim, ".")
|
||||
return src.GetPath(claimParts...).Interface()
|
||||
}
|
||||
|
||||
// coerceClaim tries to convert the value into the destination interface type.
|
||||
// If it can convert the value, it will then store the value in the destination
|
||||
// interface.
|
||||
func coerceClaim(value, dst interface{}) error {
|
||||
switch d := dst.(type) {
|
||||
case *string:
|
||||
str, err := toString(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string: %v", err)
|
||||
}
|
||||
*d = str
|
||||
case *[]string:
|
||||
strSlice, err := toStringSlice(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string slice: %v", err)
|
||||
}
|
||||
*d = strSlice
|
||||
case *bool:
|
||||
*d = cast.ToBool(value)
|
||||
default:
|
||||
return fmt.Errorf("unknown type for destination: %T", dst)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// toStringSlice converts an interface (either a slice or single value) into
|
||||
// a slice of strings.
|
||||
func toStringSlice(value interface{}) ([]string, error) {
|
||||
var sliceValues []interface{}
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
sliceValues = v
|
||||
case interface{}:
|
||||
sliceValues = []interface{}{v}
|
||||
default:
|
||||
sliceValues = cast.ToSlice(value)
|
||||
}
|
||||
|
||||
out := []string{}
|
||||
for _, v := range sliceValues {
|
||||
str, err := toString(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
|
||||
}
|
||||
out = append(out, str)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// toString coerces a value into a string.
|
||||
// If it is non-string, marshal it into JSON.
|
||||
func toString(value interface{}) (string, error) {
|
||||
if str, err := cast.ToStringE(value); err == nil {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
jsonStr, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(jsonStr), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -451,53 +451,6 @@ var _ = Describe("Claim Extractor Suite", func() {
|
|||
}),
|
||||
)
|
||||
|
||||
type coerceClaimTableInput struct {
|
||||
value interface{}
|
||||
dst interface{}
|
||||
expectedDst interface{}
|
||||
expectedError error
|
||||
}
|
||||
|
||||
DescribeTable("coerceClaim",
|
||||
func(in coerceClaimTableInput) {
|
||||
err := coerceClaim(in.value, in.dst)
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(MatchError(in.expectedError))
|
||||
return
|
||||
}
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(in.dst).To(Equal(in.expectedDst))
|
||||
},
|
||||
Entry("coerces a string to a string", coerceClaimTableInput{
|
||||
value: "some_string",
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("some_string"),
|
||||
}),
|
||||
Entry("coerces a slice to a string slice", coerceClaimTableInput{
|
||||
value: []interface{}{"a", "b"},
|
||||
dst: stringSlicePointer([]string{}),
|
||||
expectedDst: stringSlicePointer([]string{"a", "b"}),
|
||||
}),
|
||||
Entry("coerces a bool to a bool", coerceClaimTableInput{
|
||||
value: true,
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
}),
|
||||
Entry("coerces a string to a bool", coerceClaimTableInput{
|
||||
value: "true",
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
}),
|
||||
Entry("coerces a map to a string", coerceClaimTableInput{
|
||||
value: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "baz"},
|
||||
},
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
|
||||
}),
|
||||
)
|
||||
|
||||
It("should extract claims from a JWT response", func() {
|
||||
jwtResponsePayload := `{
|
||||
"user": "jwtUser",
|
||||
|
|
@ -605,10 +558,6 @@ func stringSlicePointer(in []string) *[]string {
|
|||
return &in
|
||||
}
|
||||
|
||||
func boolPointer(in bool) *bool {
|
||||
return &in
|
||||
}
|
||||
|
||||
// ******************************
|
||||
// Different profile URL handlers
|
||||
// ******************************
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
|
|
@ -12,6 +13,8 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func GetCertPool(paths []string, useSystemPool bool) (*x509.CertPool, error) {
|
||||
|
|
@ -191,3 +194,66 @@ func RemoveDuplicateStr(strSlice []string) []string {
|
|||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// CoerceClaim tries to convert the value into the destination interface type.
|
||||
// If it can convert the value, it will then store the value in the destination
|
||||
// interface.
|
||||
func CoerceClaim(value, dst interface{}) error {
|
||||
switch d := dst.(type) {
|
||||
case *string:
|
||||
str, err := toString(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string: %v", err)
|
||||
}
|
||||
*d = str
|
||||
case *[]string:
|
||||
strSlice, err := toStringSlice(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not convert value to string slice: %v", err)
|
||||
}
|
||||
*d = strSlice
|
||||
case *bool:
|
||||
*d = cast.ToBool(value)
|
||||
default:
|
||||
return fmt.Errorf("unknown type for destination: %T", dst)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// toStringSlice converts an interface (either a slice or single value) into
|
||||
// a slice of strings.
|
||||
func toStringSlice(value interface{}) ([]string, error) {
|
||||
var sliceValues []interface{}
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
sliceValues = v
|
||||
case interface{}:
|
||||
sliceValues = []interface{}{v}
|
||||
default:
|
||||
sliceValues = cast.ToSlice(value)
|
||||
}
|
||||
|
||||
out := []string{}
|
||||
for _, v := range sliceValues {
|
||||
str, err := toString(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
|
||||
}
|
||||
out = append(out, str)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// toString coerces a value into a string.
|
||||
// If it is non-string, marshal it into JSON.
|
||||
func toString(value interface{}) (string, error) {
|
||||
if str, err := cast.ToStringE(value); err == nil {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
jsonStr, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(jsonStr), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ package util
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -253,3 +255,82 @@ func TestGetCertPool(t *testing.T) {
|
|||
assert.Error(t, err3)
|
||||
}
|
||||
}
|
||||
|
||||
func stringPointer(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func stringSlicePointer(s []string) *[]string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func boolPointer(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
type coerceClaimTableInput struct {
|
||||
name string
|
||||
value interface{}
|
||||
dst interface{}
|
||||
expectedDst interface{}
|
||||
expectedError error
|
||||
}
|
||||
|
||||
func TestCoerceClaim(t *testing.T) {
|
||||
tests := []coerceClaimTableInput{
|
||||
{
|
||||
name: "coerces a string to a string",
|
||||
value: "some_string",
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("some_string"),
|
||||
},
|
||||
{
|
||||
name: "coerces a slice to a string slice",
|
||||
value: []interface{}{"a", "b"},
|
||||
dst: stringSlicePointer([]string{}),
|
||||
expectedDst: stringSlicePointer([]string{"a", "b"}),
|
||||
},
|
||||
{
|
||||
name: "coerces a bool to a bool",
|
||||
value: true,
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
},
|
||||
{
|
||||
name: "coerces a string to a bool",
|
||||
value: "true",
|
||||
dst: boolPointer(false),
|
||||
expectedDst: boolPointer(true),
|
||||
},
|
||||
{
|
||||
name: "coerces a map to a string",
|
||||
value: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "baz"},
|
||||
},
|
||||
dst: stringPointer(""),
|
||||
expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CoerceClaim(tt.value, tt.dst)
|
||||
if tt.expectedError != nil {
|
||||
if err == nil || err.Error() != tt.expectedError.Error() {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.dst, tt.expectedDst) {
|
||||
gotJSON, _ := json.Marshal(tt.dst)
|
||||
wantJSON, _ := json.Marshal(tt.expectedDst)
|
||||
t.Errorf("expected dst to be %+v, got %+v", string(wantJSON), string(gotJSON))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue