211 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			211 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Go
		
	
	
	
package util
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/base64"
 | 
						|
	"encoding/json"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"github.com/bitly/go-simplejson"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
 | 
						|
	"github.com/spf13/cast"
 | 
						|
)
 | 
						|
 | 
						|
// ClaimExtractor is used to extract claim values from an ID Token, or, if not
 | 
						|
// present, from the profile URL.
 | 
						|
type ClaimExtractor interface {
 | 
						|
	// GetClaim fetches a named claim and returns the value.
 | 
						|
	GetClaim(claim string) (interface{}, bool, error)
 | 
						|
 | 
						|
	// GetClaimInto fetches a named claim and puts the value into the destination.
 | 
						|
	GetClaimInto(claim string, dst interface{}) (bool, error)
 | 
						|
}
 | 
						|
 | 
						|
// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
 | 
						|
// If needed, it will use the profile URL to look up a claim if it isn't present
 | 
						|
// within the ID Token.
 | 
						|
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
 | 
						|
	payload, err := parseJWT(idToken)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to parse ID Token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	tokenClaims, err := simplejson.NewJson(payload)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return &claimExtractor{
 | 
						|
		ctx:            ctx,
 | 
						|
		profileURL:     profileURL,
 | 
						|
		requestHeaders: profileRequestHeaders,
 | 
						|
		tokenClaims:    tokenClaims,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
// claimExtractor implements the ClaimExtractor interface
 | 
						|
type claimExtractor struct {
 | 
						|
	profileURL     *url.URL
 | 
						|
	ctx            context.Context
 | 
						|
	requestHeaders map[string][]string
 | 
						|
	tokenClaims    *simplejson.Json
 | 
						|
	profileClaims  *simplejson.Json
 | 
						|
}
 | 
						|
 | 
						|
// GetClaim will return the value claim if it exists.
 | 
						|
// It will only return an error if the profile URL needs to be fetched due to
 | 
						|
// the claim not being present in the ID Token.
 | 
						|
func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
 | 
						|
	if claim == "" {
 | 
						|
		return nil, false, nil
 | 
						|
	}
 | 
						|
 | 
						|
	if value := getClaimFrom(claim, c.tokenClaims); value != nil {
 | 
						|
		return value, true, nil
 | 
						|
	}
 | 
						|
 | 
						|
	if c.profileClaims == nil {
 | 
						|
		profileClaims, err := c.loadProfileClaims()
 | 
						|
		if err != nil {
 | 
						|
			return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
 | 
						|
		}
 | 
						|
 | 
						|
		c.profileClaims = profileClaims
 | 
						|
	}
 | 
						|
 | 
						|
	if value := getClaimFrom(claim, c.profileClaims); value != nil {
 | 
						|
		return value, true, nil
 | 
						|
	}
 | 
						|
 | 
						|
	return nil, false, nil
 | 
						|
}
 | 
						|
 | 
						|
// loadProfileClaims will fetch the profileURL using the provided headers as
 | 
						|
// authentication.
 | 
						|
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
 | 
						|
	if c.profileURL == nil || c.profileURL.String() == "" || c.requestHeaders == nil {
 | 
						|
		// When no profileURL is set, we return a non-empty map so that
 | 
						|
		// we don't attempt to populate the profile claims again.
 | 
						|
		// If there are no headers, the request would be unauthorized so we also skip
 | 
						|
		// in this case too.
 | 
						|
		return simplejson.New(), nil
 | 
						|
	}
 | 
						|
 | 
						|
	claims, err := requests.New(c.profileURL.String()).
 | 
						|
		WithContext(c.ctx).
 | 
						|
		WithHeaders(c.requestHeaders).
 | 
						|
		Do().
 | 
						|
		UnmarshalJSON()
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("error making request to profile URL: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return claims, nil
 | 
						|
}
 | 
						|
 | 
						|
// GetClaimInto loads a claim and places it into the destination interface.
 | 
						|
// This will attempt to coerce the claim into the specified type.
 | 
						|
// If it cannot be coerced, an error may be returned.
 | 
						|
func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
 | 
						|
	value, exists, err := c.GetClaim(claim)
 | 
						|
	if err != nil {
 | 
						|
		return false, fmt.Errorf("could not get claim %q: %v", claim, err)
 | 
						|
	}
 | 
						|
	if !exists {
 | 
						|
		return false, nil
 | 
						|
	}
 | 
						|
	if err := coerceClaim(value, dst); err != nil {
 | 
						|
		return false, fmt.Errorf("could no coerce claim: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return true, nil
 | 
						|
}
 | 
						|
 | 
						|
// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
 | 
						|
// We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
 | 
						|
func parseJWT(p string) ([]byte, error) {
 | 
						|
	parts := strings.Split(p, ".")
 | 
						|
	if len(parts) < 2 {
 | 
						|
		return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
 | 
						|
	}
 | 
						|
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
 | 
						|
	}
 | 
						|
	return payload, nil
 | 
						|
}
 | 
						|
 | 
						|
// getClaimFrom gets a claim from a Json object.
 | 
						|
// It can accept either a single claim name or a json path.
 | 
						|
// Paths with indexes are not supported.
 | 
						|
func getClaimFrom(claim string, src *simplejson.Json) interface{} {
 | 
						|
	claimParts := strings.Split(claim, ".")
 | 
						|
	return src.GetPath(claimParts...).Interface()
 | 
						|
}
 | 
						|
 | 
						|
// coerceClaim tries to convert the value into the destination interface type.
 | 
						|
// If it can convert the value, it will then store the value in the destination
 | 
						|
// interface.
 | 
						|
func coerceClaim(value, dst interface{}) error {
 | 
						|
	switch d := dst.(type) {
 | 
						|
	case *string:
 | 
						|
		str, err := toString(value)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("could not convert value to string: %v", err)
 | 
						|
		}
 | 
						|
		*d = str
 | 
						|
	case *[]string:
 | 
						|
		strSlice, err := toStringSlice(value)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("could not convert value to string slice: %v", err)
 | 
						|
		}
 | 
						|
		*d = strSlice
 | 
						|
	case *bool:
 | 
						|
		*d = cast.ToBool(value)
 | 
						|
	default:
 | 
						|
		return fmt.Errorf("unknown type for destination: %T", dst)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// toStringSlice converts an interface (either a slice or single value) into
 | 
						|
// a slice of strings.
 | 
						|
func toStringSlice(value interface{}) ([]string, error) {
 | 
						|
	var sliceValues []interface{}
 | 
						|
	switch v := value.(type) {
 | 
						|
	case []interface{}:
 | 
						|
		sliceValues = v
 | 
						|
	case interface{}:
 | 
						|
		sliceValues = []interface{}{v}
 | 
						|
	default:
 | 
						|
		sliceValues = cast.ToSlice(value)
 | 
						|
	}
 | 
						|
 | 
						|
	out := []string{}
 | 
						|
	for _, v := range sliceValues {
 | 
						|
		str, err := toString(v)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
 | 
						|
		}
 | 
						|
		out = append(out, str)
 | 
						|
	}
 | 
						|
	return out, nil
 | 
						|
}
 | 
						|
 | 
						|
// toString coerces a value into a string.
 | 
						|
// If it is non-string, marshal it into JSON.
 | 
						|
func toString(value interface{}) (string, error) {
 | 
						|
	if str, err := cast.ToStringE(value); err == nil {
 | 
						|
		return str, nil
 | 
						|
	}
 | 
						|
 | 
						|
	jsonStr, err := json.Marshal(value)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	return string(jsonStr), nil
 | 
						|
}
 |