oauth2-proxy/pkg/providers/util/claim_extractor.go

165 lines
5.1 KiB
Go

package util
import (
"context"
"encoding/base64"
"fmt"
"mime"
"net/http"
"net/url"
"strings"
"github.com/bitly/go-simplejson"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
)
// ClaimExtractor looks up claim values by name, taking them from an ID Token
// or, when a claim is not present there, from the profile URL.
type ClaimExtractor interface {
	// GetClaim looks up the named claim. It returns the claim's value, a flag
	// reporting whether the claim was found, and any error encountered during
	// the lookup.
	GetClaim(claim string) (any, bool, error)

	// GetClaimInto looks up the named claim and stores its value in dst. It
	// returns a flag reporting whether the claim was found and stored, and
	// any error encountered along the way.
	GetClaimInto(claim string, dst any) (bool, error)
}
// NewClaimExtractor builds a ClaimExtractor backed by the claims carried in
// the raw ID Token. The profile URL and request headers are retained so that
// a claim missing from the ID Token can later be looked up via the profile
// URL.
//
// An error is returned when the ID Token cannot be decoded as a JWT, or when
// its payload cannot be parsed as JSON.
func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) {
	rawClaims, err := parseJWT(idToken)
	if err != nil {
		return nil, fmt.Errorf("failed to parse ID Token: %w", err)
	}

	parsed, err := simplejson.NewJson(rawClaims)
	if err != nil {
		return nil, fmt.Errorf("failed to parse ID Token payload: %w", err)
	}

	// profileClaims is intentionally left unset; it is filled in on demand.
	extractor := &claimExtractor{
		ctx:            ctx,
		profileURL:     profileURL,
		requestHeaders: profileRequestHeaders,
		tokenClaims:    parsed,
	}
	return extractor, nil
}
// claimExtractor is the ClaimExtractor implementation: it answers lookups
// from the ID Token claims and falls back to claims fetched from the profile
// URL.
type claimExtractor struct {
	// profileURL is the endpoint queried for claims missing from the ID Token.
	profileURL *url.URL
	// ctx is attached to the request made to the profile URL.
	ctx context.Context
	// requestHeaders are sent with the request made to the profile URL.
	requestHeaders map[string][]string
	// tokenClaims holds the parsed payload of the ID Token.
	tokenClaims *simplejson.Json
	// profileClaims caches the claims loaded from the profile URL. It stays
	// nil until a lookup falls through to the profile URL and the load
	// succeeds.
	profileClaims *simplejson.Json
}
// GetClaim returns the value of the named claim and whether it was found.
//
// The ID Token claims are consulted first. When they yield no value (a nil
// value is treated the same as an absent claim), the profile claims are
// consulted instead; these are loaded from the profile URL on first use and
// cached for subsequent calls.
//
// An empty claim name is reported as not found. An error is only returned
// when the profile claims have to be loaded and that load fails; the
// underlying error is wrapped so callers can inspect it with errors.Is or
// errors.As.
func (c *claimExtractor) GetClaim(claim string) (any, bool, error) {
	if claim == "" {
		return nil, false, nil
	}

	if value := getClaimFrom(claim, c.tokenClaims); value != nil {
		return value, true, nil
	}

	// Load the profile claims lazily. They are cached only on success, so a
	// failed load is attempted again on the next call that needs them.
	if c.profileClaims == nil {
		profileClaims, err := c.loadProfileClaims()
		if err != nil {
			return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %w", err)
		}
		c.profileClaims = profileClaims
	}

	if value := getClaimFrom(claim, c.profileClaims); value != nil {
		return value, true, nil
	}

	return nil, false, nil
}
// loadProfileClaims fetches claims from the profile URL, sending the stored
// request headers (used as authentication) with the request.
//
// When no profile URL or no request headers are set, no request is made and a
// new, non-nil simplejson object is returned. If the response declares the
// application/jwt media type and its body decodes as a JWT, the JWT payload
// supplies the claims; otherwise the response is unmarshalled as plain JSON.
// A failure from that unmarshalling step is wrapped so callers can inspect
// the underlying cause with errors.Is or errors.As.
func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) {
	if c.profileURL == nil || c.profileURL.String() == "" || c.requestHeaders == nil {
		// Returning a non-nil object means the caller will not attempt to
		// populate the profile claims again. Without headers the request
		// would be unauthorized, so that case is skipped as well.
		return simplejson.New(), nil
	}

	result := requests.New(c.profileURL.String()).
		WithContext(c.ctx).
		WithHeaders(c.requestHeaders).
		Do()

	// The UserInfo response may be a JWT rather than a plain JSON document:
	// https://openid.net/specs/openid-connect-core-1_0-final.html#UserInfoResponse
	mediaType, _, parseErr := mime.ParseMediaType(result.Headers().Get("Content-Type"))
	if parseErr == nil && mediaType == "application/jwt" {
		// A body that does not decode as a JWT deliberately falls through
		// to the plain JSON handling below.
		if payload, err := parseJWT(string(result.Body())); err == nil {
			return simplejson.NewJson(payload)
		}
	}

	// Otherwise, treat the response as a normal JSON payload.
	claims, err := result.UnmarshalSimpleJSON()
	if err != nil {
		return nil, fmt.Errorf("error making request to profile URL: %w", err)
	}
	return claims, nil
}
// GetClaimInto looks up the named claim via GetClaim and, if it exists,
// coerces its value into dst using util.CoerceClaim.
//
// It returns true only when the claim was found and successfully coerced.
// A missing claim yields (false, nil). Lookup and coercion failures are
// returned as wrapped errors, so the underlying cause remains available to
// errors.Is and errors.As.
func (c *claimExtractor) GetClaimInto(claim string, dst any) (bool, error) {
	value, exists, err := c.GetClaim(claim)
	if err != nil {
		return false, fmt.Errorf("could not get claim %q: %w", claim, err)
	}
	if !exists {
		return false, nil
	}

	if err := util.CoerceClaim(value, dst); err != nil {
		return false, fmt.Errorf("could not coerce claim: %w", err)
	}
	return true, nil
}
// parseJWT returns the decoded payload (second dot-separated segment) of a
// raw JWT. It is used to grab the raw token payload so that it can be parsed
// into the JSON library.
//
// Copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130
//
// Only the payload segment is decoded; the header and signature segments are
// neither decoded nor checked here. Although the error message mentions three
// parts, any input with at least two segments is accepted. A payload that is
// not valid unpadded base64url yields an error wrapping the decoder's error,
// so callers can inspect it with errors.Is or errors.As.
func parseJWT(p string) ([]byte, error) {
	parts := strings.Split(p, ".")
	if len(parts) < 2 {
		return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
	}

	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("oidc: malformed jwt payload: %w", err)
	}
	return payload, nil
}
// getClaimFrom looks up a claim in a Json object.
// The claim is first tried verbatim as a single key. Only when no such key
// exists is it split on "." and resolved as a path of nested keys.
// Paths with indexes are not supported.
func getClaimFrom(claim string, src *simplejson.Json) any {
	direct, found := src.CheckGet(claim)
	if !found {
		// No key with this exact name: fall back to treating it as a path.
		segments := strings.Split(claim, ".")
		nested := src.GetPath(segments...)
		return nested.Interface()
	}
	return direct.Interface()
}