229 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			229 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Go
		
	
	
	
| package util
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"mime"
 | |
| 	"net/http"
 | |
| 	"net/url"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/bitly/go-simplejson"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
 | |
| 	"github.com/spf13/cast"
 | |
| )
 | |
| 
 | |
| // ClaimExtractor is used to extract claim values from an ID Token, or, if not
 | |
| // present, from the profile URL.
 | |
| type ClaimExtractor interface {
 | |
| 	// GetClaim fetches a named claim and returns the value.
 | |
| 	GetClaim(claim string) (interface{}, bool, error)
 | |
| 
 | |
| 	// GetClaimInto fetches a named claim and puts the value into the destination.
 | |
| 	GetClaimInto(claim string, dst interface{}) (bool, error)
 | |
| }
 | |
| 
 | |
| // NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token.
 | |
| // If needed, it will use the profile URL to look up a claim if it isn't present
 | |
| // within the ID Token.
 | |
| func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
 | |
| 	payload, err := parseJWT(idToken)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to parse ID Token: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	tokenClaims, err := simplejson.NewJson(payload)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to parse ID Token payload: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return &claimExtractor{
 | |
| 		ctx:            ctx,
 | |
| 		profileURL:     profileURL,
 | |
| 		requestHeaders: profileRequestHeaders,
 | |
| 		tokenClaims:    tokenClaims,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // claimExtractor implements the ClaimExtractor interface
 | |
| type claimExtractor struct {
 | |
| 	profileURL     *url.URL
 | |
| 	ctx            context.Context
 | |
| 	requestHeaders map[string][]string
 | |
| 	tokenClaims    *simplejson.Json
 | |
| 	profileClaims  *simplejson.Json
 | |
| }
 | |
| 
 | |
| // GetClaim will return the value claim if it exists.
 | |
| // It will only return an error if the profile URL needs to be fetched due to
 | |
| // the claim not being present in the ID Token.
 | |
| func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) {
 | |
| 	if claim == "" {
 | |
| 		return nil, false, nil
 | |
| 	}
 | |
| 
 | |
| 	if value := getClaimFrom(claim, c.tokenClaims); value != nil {
 | |
| 		return value, true, nil
 | |
| 	}
 | |
| 
 | |
| 	if c.profileClaims == nil {
 | |
| 		profileClaims, err := c.loadProfileClaims()
 | |
| 		if err != nil {
 | |
| 			return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		c.profileClaims = profileClaims
 | |
| 	}
 | |
| 
 | |
| 	if value := getClaimFrom(claim, c.profileClaims); value != nil {
 | |
| 		return value, true, nil
 | |
| 	}
 | |
| 
 | |
| 	return nil, false, nil
 | |
| }
 | |
| 
 | |
| // loadProfileClaims will fetch the profileURL using the provided headers as
 | |
| // authentication.
 | |
| func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
 | |
| 	if c.profileURL == nil || c.profileURL.String() == "" || c.requestHeaders == nil {
 | |
| 		// When no profileURL is set, we return a non-empty map so that
 | |
| 		// we don't attempt to populate the profile claims again.
 | |
| 		// If there are no headers, the request would be unauthorized so we also skip
 | |
| 		// in this case too.
 | |
| 		return simplejson.New(), nil
 | |
| 	}
 | |
| 
 | |
| 	builder := requests.New(c.profileURL.String()).
 | |
| 		WithContext(c.ctx).
 | |
| 		WithHeaders(c.requestHeaders).
 | |
| 		Do()
 | |
| 
 | |
| 	// We first check if the result is a JWT token
 | |
| 	// https://openid.net/specs/openid-connect-core-1_0-final.html#UserInfoResponse
 | |
| 	mediaType, _, parseErr := mime.ParseMediaType(builder.Headers().Get("Content-Type"))
 | |
| 
 | |
| 	if parseErr == nil && mediaType == "application/jwt" {
 | |
| 		// Decode and use JWT payload as profile claims
 | |
| 		if pl, err := parseJWT(string(builder.Body())); err == nil {
 | |
| 			return simplejson.NewJson(pl)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Otherwise, process as normal JSON payload
 | |
| 	claims, err := builder.UnmarshalSimpleJSON()
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("error making request to profile URL: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return claims, nil
 | |
| }
 | |
| 
 | |
| // GetClaimInto loads a claim and places it into the destination interface.
 | |
| // This will attempt to coerce the claim into the specified type.
 | |
| // If it cannot be coerced, an error may be returned.
 | |
| func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) {
 | |
| 	value, exists, err := c.GetClaim(claim)
 | |
| 	if err != nil {
 | |
| 		return false, fmt.Errorf("could not get claim %q: %v", claim, err)
 | |
| 	}
 | |
| 	if !exists {
 | |
| 		return false, nil
 | |
| 	}
 | |
| 	if err := coerceClaim(value, dst); err != nil {
 | |
| 		return false, fmt.Errorf("could no coerce claim: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return true, nil
 | |
| }
 | |
| 
 | |
| // This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
 | |
| // We use it to grab the raw ID Token payload so that we can parse it into the JSON library.
 | |
| func parseJWT(p string) ([]byte, error) {
 | |
| 	parts := strings.Split(p, ".")
 | |
| 	if len(parts) < 2 {
 | |
| 		return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
 | |
| 	}
 | |
| 	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
 | |
| 	}
 | |
| 	return payload, nil
 | |
| }
 | |
| 
 | |
| // getClaimFrom gets a claim from a Json object.
 | |
| // It can accept either a single claim name or a json path. The claim is always evaluated first as a single claim name.
 | |
| // Paths with indexes are not supported.
 | |
| func getClaimFrom(claim string, src *simplejson.Json) interface{} {
 | |
| 	if value, ok := src.CheckGet(claim); ok {
 | |
| 		return value.Interface()
 | |
| 	}
 | |
| 	claimParts := strings.Split(claim, ".")
 | |
| 	return src.GetPath(claimParts...).Interface()
 | |
| }
 | |
| 
 | |
| // coerceClaim tries to convert the value into the destination interface type.
 | |
| // If it can convert the value, it will then store the value in the destination
 | |
| // interface.
 | |
| func coerceClaim(value, dst interface{}) error {
 | |
| 	switch d := dst.(type) {
 | |
| 	case *string:
 | |
| 		str, err := toString(value)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("could not convert value to string: %v", err)
 | |
| 		}
 | |
| 		*d = str
 | |
| 	case *[]string:
 | |
| 		strSlice, err := toStringSlice(value)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("could not convert value to string slice: %v", err)
 | |
| 		}
 | |
| 		*d = strSlice
 | |
| 	case *bool:
 | |
| 		*d = cast.ToBool(value)
 | |
| 	default:
 | |
| 		return fmt.Errorf("unknown type for destination: %T", dst)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // toStringSlice converts an interface (either a slice or single value) into
 | |
| // a slice of strings.
 | |
| func toStringSlice(value interface{}) ([]string, error) {
 | |
| 	var sliceValues []interface{}
 | |
| 	switch v := value.(type) {
 | |
| 	case []interface{}:
 | |
| 		sliceValues = v
 | |
| 	case interface{}:
 | |
| 		sliceValues = []interface{}{v}
 | |
| 	default:
 | |
| 		sliceValues = cast.ToSlice(value)
 | |
| 	}
 | |
| 
 | |
| 	out := []string{}
 | |
| 	for _, v := range sliceValues {
 | |
| 		str, err := toString(v)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err)
 | |
| 		}
 | |
| 		out = append(out, str)
 | |
| 	}
 | |
| 	return out, nil
 | |
| }
 | |
| 
 | |
| // toString coerces a value into a string.
 | |
| // If it is non-string, marshal it into JSON.
 | |
| func toString(value interface{}) (string, error) {
 | |
| 	if str, err := cast.ToStringE(value); err == nil {
 | |
| 		return str, nil
 | |
| 	}
 | |
| 
 | |
| 	jsonStr, err := json.Marshal(value)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	return string(jsonStr), nil
 | |
| }
 |