459 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			459 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
package providers
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"slices"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/bitly/go-simplejson"
 | 
						|
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
 | 
						|
)
 | 
						|
 | 
						|
// AzureProvider represents an Azure based Identity Provider
 | 
						|
type AzureProvider struct {
 | 
						|
	*ProviderData
 | 
						|
	Tenant          string
 | 
						|
	GraphGroupField string
 | 
						|
	isV2Endpoint    bool
 | 
						|
}
 | 
						|
 | 
						|
var _ Provider = (*AzureProvider)(nil)
 | 
						|
 | 
						|
const (
 | 
						|
	azureProviderName           = "Azure"
 | 
						|
	azureDefaultScope           = "openid"
 | 
						|
	azureDefaultGraphGroupField = "id"
 | 
						|
)
 | 
						|
 | 
						|
var (
 | 
						|
	// Default Login URL for Azure. Pre-parsed URL of https://login.microsoftonline.com/common/oauth2/authorize.
 | 
						|
	azureDefaultLoginURL = &url.URL{
 | 
						|
		Scheme: "https",
 | 
						|
		Host:   "login.microsoftonline.com",
 | 
						|
		Path:   "/common/oauth2/authorize",
 | 
						|
	}
 | 
						|
 | 
						|
	// Default Redeem URL for Azure. Pre-parsed URL of https://login.microsoftonline.com/common/oauth2/token.
 | 
						|
	azureDefaultRedeemURL = &url.URL{
 | 
						|
		Scheme: "https",
 | 
						|
		Host:   "login.microsoftonline.com",
 | 
						|
		Path:   "/common/oauth2/token",
 | 
						|
	}
 | 
						|
 | 
						|
	// Default Profile URL for Azure. Pre-parsed URL of https://graph.microsoft.com/v1.0/me.
 | 
						|
	azureDefaultProfileURL = &url.URL{
 | 
						|
		Scheme: "https",
 | 
						|
		Host:   "graph.microsoft.com",
 | 
						|
		Path:   "/v1.0/me",
 | 
						|
	}
 | 
						|
)
 | 
						|
 | 
						|
// NewAzureProvider initiates a new AzureProvider
 | 
						|
func NewAzureProvider(p *ProviderData, opts options.AzureOptions) *AzureProvider {
 | 
						|
	p.setProviderDefaults(providerDefaults{
 | 
						|
		name:        azureProviderName,
 | 
						|
		loginURL:    azureDefaultLoginURL,
 | 
						|
		redeemURL:   azureDefaultRedeemURL,
 | 
						|
		profileURL:  azureDefaultProfileURL,
 | 
						|
		validateURL: nil,
 | 
						|
		scope:       azureDefaultScope,
 | 
						|
	})
 | 
						|
 | 
						|
	if p.ValidateURL == nil || p.ValidateURL.String() == "" {
 | 
						|
		p.ValidateURL = p.ProfileURL
 | 
						|
	}
 | 
						|
	p.getAuthorizationHeaderFunc = makeAzureHeader
 | 
						|
 | 
						|
	tenant := "common"
 | 
						|
	if opts.Tenant != "" {
 | 
						|
		tenant = opts.Tenant
 | 
						|
		overrideTenantURL(p.LoginURL, azureDefaultLoginURL, tenant, "authorize")
 | 
						|
		overrideTenantURL(p.RedeemURL, azureDefaultRedeemURL, tenant, "token")
 | 
						|
	}
 | 
						|
 | 
						|
	graphGroupField := azureDefaultGraphGroupField
 | 
						|
	if opts.GraphGroupField != "" {
 | 
						|
		graphGroupField = opts.GraphGroupField
 | 
						|
	}
 | 
						|
 | 
						|
	isV2Endpoint := false
 | 
						|
	if strings.Contains(p.LoginURL.String(), "v2.0") {
 | 
						|
		isV2Endpoint = true
 | 
						|
		azureV2GraphScope := fmt.Sprintf("https://%s/.default", p.ProfileURL.Host)
 | 
						|
 | 
						|
		if strings.Contains(p.Scope, " groups") {
 | 
						|
			logger.Print("WARNING: `groups` scope is not an accepted scope when using Azure OAuth V2 endpoint. Removing it from the scope list")
 | 
						|
			p.Scope = strings.ReplaceAll(p.Scope, " groups", "")
 | 
						|
		}
 | 
						|
 | 
						|
		if !strings.Contains(p.Scope, " "+azureV2GraphScope) {
 | 
						|
			// In order to be able to query MS Graph we must pass the ms graph default endpoint
 | 
						|
			p.Scope += " " + azureV2GraphScope
 | 
						|
		}
 | 
						|
 | 
						|
		if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
 | 
						|
			logger.Print("WARNING: `--resource` option has no effect when using the Azure OAuth V2 endpoint.")
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return &AzureProvider{
 | 
						|
		ProviderData:    p,
 | 
						|
		Tenant:          tenant,
 | 
						|
		GraphGroupField: graphGroupField,
 | 
						|
		isV2Endpoint:    isV2Endpoint,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
 | 
						|
	if current == nil || current.String() == "" || current.String() == defaultURL.String() {
 | 
						|
		*current = url.URL{
 | 
						|
			Scheme: "https",
 | 
						|
			Host:   current.Host,
 | 
						|
			Path:   "/" + tenant + "/oauth2/" + path}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func getMicrosoftGraphGroupsURL(profileURL *url.URL, graphGroupField string) *url.URL {
 | 
						|
 | 
						|
	selectStatement := "$select=displayName,id"
 | 
						|
	if !slices.Contains([]string{"displayName", "id"}, graphGroupField) {
 | 
						|
		selectStatement += "," + graphGroupField
 | 
						|
	}
 | 
						|
 | 
						|
	// Select only security groups. Due to the filter option, count param is mandatory even if unused otherwise
 | 
						|
	return &url.URL{
 | 
						|
		Scheme:   "https",
 | 
						|
		Host:     profileURL.Host,
 | 
						|
		Path:     "/v1.0/me/transitiveMemberOf",
 | 
						|
		RawQuery: "$count=true&$filter=securityEnabled+eq+true&" + selectStatement,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) GetLoginURL(redirectURI, state, _ string, extraParams url.Values) string {
 | 
						|
	// In azure oauth v2 there is no resource param so add it only if V1 endpoint
 | 
						|
	// https://docs.microsoft.com/en-us/azure/active-directory/azuread-dev/azure-ad-endpoint-comparison#scopes-not-resources
 | 
						|
	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" && !p.isV2Endpoint {
 | 
						|
		extraParams.Add("resource", p.ProtectedResource.String())
 | 
						|
	}
 | 
						|
	a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams)
 | 
						|
	return a.String()
 | 
						|
}
 | 
						|
 | 
						|
// Redeem exchanges the OAuth2 authentication token for an ID token
 | 
						|
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string) (*sessions.SessionState, error) {
 | 
						|
	params, err := p.prepareRedeem(redirectURL, code, codeVerifier)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// blindly try json and x-www-form-urlencoded
 | 
						|
	var jsonResponse struct {
 | 
						|
		AccessToken  string `json:"access_token"`
 | 
						|
		RefreshToken string `json:"refresh_token"`
 | 
						|
		ExpiresOn    int64  `json:"expires_on,string"`
 | 
						|
		IDToken      string `json:"id_token"`
 | 
						|
	}
 | 
						|
 | 
						|
	err = requests.New(p.RedeemURL.String()).
 | 
						|
		WithContext(ctx).
 | 
						|
		WithMethod("POST").
 | 
						|
		WithBody(bytes.NewBufferString(params.Encode())).
 | 
						|
		SetHeader("Content-Type", "application/x-www-form-urlencoded").
 | 
						|
		Do().
 | 
						|
		UnmarshalInto(&jsonResponse)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	session := &sessions.SessionState{
 | 
						|
		AccessToken:  jsonResponse.AccessToken,
 | 
						|
		IDToken:      jsonResponse.IDToken,
 | 
						|
		RefreshToken: jsonResponse.RefreshToken,
 | 
						|
	}
 | 
						|
	session.CreatedAtNow()
 | 
						|
	session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
 | 
						|
 | 
						|
	err = p.extractClaimsIntoSession(ctx, session)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("unable to get email and/or groups claims from token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return session, nil
 | 
						|
}
 | 
						|
 | 
						|
// EnrichSession enriches the session state with userID, mail and groups
 | 
						|
func (p *AzureProvider) EnrichSession(ctx context.Context, session *sessions.SessionState) error {
 | 
						|
	err := p.extractClaimsIntoSession(ctx, session)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		logger.Printf("unable to get email and/or groups claims from token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if session.Email == "" {
 | 
						|
		email, err := p.getEmailFromProfileAPI(ctx, session.AccessToken)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("unable to get email address from profile URL: %v", err)
 | 
						|
		}
 | 
						|
		session.Email = email
 | 
						|
	}
 | 
						|
 | 
						|
	// If using the v2.0 oidc endpoint we're also querying Microsoft Graph
 | 
						|
	if p.isV2Endpoint {
 | 
						|
		groups, err := p.getGroupsFromProfileAPI(ctx, session)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("unable to get groups from Microsoft Graph: %v", err)
 | 
						|
		}
 | 
						|
		session.Groups = util.RemoveDuplicateStr(append(session.Groups, groups...))
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) prepareRedeem(redirectURL, code, codeVerifier string) (url.Values, error) {
 | 
						|
	params := url.Values{}
 | 
						|
	if code == "" {
 | 
						|
		return params, ErrMissingCode
 | 
						|
	}
 | 
						|
	clientSecret, err := p.GetClientSecret()
 | 
						|
	if err != nil {
 | 
						|
		return params, err
 | 
						|
	}
 | 
						|
 | 
						|
	params.Add("redirect_uri", redirectURL)
 | 
						|
	params.Add("client_id", p.ClientID)
 | 
						|
	params.Add("client_secret", clientSecret)
 | 
						|
	params.Add("code", code)
 | 
						|
	params.Add("grant_type", "authorization_code")
 | 
						|
	if codeVerifier != "" {
 | 
						|
		params.Add("code_verifier", codeVerifier)
 | 
						|
	}
 | 
						|
 | 
						|
	// In azure oauth v2 there is no resource param so add it only if V1 endpoint
 | 
						|
	// https://docs.microsoft.com/en-us/azure/active-directory/azuread-dev/azure-ad-endpoint-comparison#scopes-not-resources
 | 
						|
	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" && !p.isV2Endpoint {
 | 
						|
		params.Add("resource", p.ProtectedResource.String())
 | 
						|
	}
 | 
						|
 | 
						|
	return params, nil
 | 
						|
}
 | 
						|
 | 
						|
// extractClaimsIntoSession tries to extract email and groups claims from either id_token or access token
 | 
						|
// when oidc verifier is configured
 | 
						|
func (p *AzureProvider) extractClaimsIntoSession(ctx context.Context, session *sessions.SessionState) error {
 | 
						|
 | 
						|
	var s *sessions.SessionState
 | 
						|
 | 
						|
	// First let's verify session token
 | 
						|
	if err := p.verifySessionToken(ctx, session); err != nil {
 | 
						|
		return fmt.Errorf("unable to verify token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// https://github.com/oauth2-proxy/oauth2-proxy/pull/914#issuecomment-782285814
 | 
						|
	// https://github.com/AzureAD/azure-activedirectory-library-for-java/issues/117
 | 
						|
	// due to above issues, id_token may not be signed by AAD
 | 
						|
	// in that case, we will fallback to access token
 | 
						|
	var err error
 | 
						|
	s, err = p.buildSessionFromClaims(session.IDToken, session.AccessToken)
 | 
						|
	if err != nil || s.Email == "" {
 | 
						|
		s, err = p.buildSessionFromClaims(session.AccessToken, session.AccessToken)
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("unable to get claims from token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	session.Email = s.Email
 | 
						|
	if s.Groups != nil {
 | 
						|
		session.Groups = s.Groups
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// verifySessionToken tries to validate id_token if present or access token when oidc verifier is configured
 | 
						|
func (p *AzureProvider) verifySessionToken(ctx context.Context, session *sessions.SessionState) error {
 | 
						|
	// Without a verifier there's no way to verify
 | 
						|
	if p.Verifier == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	if session.IDToken != "" {
 | 
						|
		if _, err := p.Verifier.Verify(ctx, session.IDToken); err != nil {
 | 
						|
			logger.Printf("unable to verify ID token, fallback to access token: %v", err)
 | 
						|
			if _, err = p.Verifier.Verify(ctx, session.AccessToken); err != nil {
 | 
						|
				return fmt.Errorf("unable to verify access token: %v", err)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	} else if _, err := p.Verifier.Verify(ctx, session.AccessToken); err != nil {
 | 
						|
		return fmt.Errorf("unable to verify access token: %v", err)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | 
						|
func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
 | 
						|
	if s == nil || s.RefreshToken == "" {
 | 
						|
		return false, nil
 | 
						|
	}
 | 
						|
 | 
						|
	err := p.redeemRefreshToken(ctx, s)
 | 
						|
	if err != nil {
 | 
						|
		return false, fmt.Errorf("unable to redeem refresh token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return true, nil
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
 | 
						|
	clientSecret, err := p.GetClientSecret()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	params := url.Values{}
 | 
						|
	params.Add("client_id", p.ClientID)
 | 
						|
	params.Add("client_secret", clientSecret)
 | 
						|
	params.Add("refresh_token", s.RefreshToken)
 | 
						|
	params.Add("grant_type", "refresh_token")
 | 
						|
 | 
						|
	var jsonResponse struct {
 | 
						|
		AccessToken  string `json:"access_token"`
 | 
						|
		RefreshToken string `json:"refresh_token"`
 | 
						|
		ExpiresOn    int64  `json:"expires_on,string"`
 | 
						|
		IDToken      string `json:"id_token"`
 | 
						|
	}
 | 
						|
 | 
						|
	err = requests.New(p.RedeemURL.String()).
 | 
						|
		WithContext(ctx).
 | 
						|
		WithMethod("POST").
 | 
						|
		WithBody(bytes.NewBufferString(params.Encode())).
 | 
						|
		SetHeader("Content-Type", "application/x-www-form-urlencoded").
 | 
						|
		Do().
 | 
						|
		UnmarshalInto(&jsonResponse)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	s.AccessToken = jsonResponse.AccessToken
 | 
						|
	s.IDToken = jsonResponse.IDToken
 | 
						|
	s.RefreshToken = jsonResponse.RefreshToken
 | 
						|
 | 
						|
	s.CreatedAtNow()
 | 
						|
	s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0))
 | 
						|
 | 
						|
	err = p.extractClaimsIntoSession(ctx, s)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		logger.Printf("unable to get email and/or groups claims from token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func makeAzureHeader(accessToken string) http.Header {
 | 
						|
	return makeAuthorizationHeader(tokenTypeBearer, accessToken, nil)
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) getGroupsFromProfileAPI(ctx context.Context, s *sessions.SessionState) ([]string, error) {
 | 
						|
	if s.AccessToken == "" {
 | 
						|
		return nil, fmt.Errorf("missing access token")
 | 
						|
	}
 | 
						|
 | 
						|
	groupsURL := getMicrosoftGraphGroupsURL(p.ProfileURL, p.GraphGroupField).String()
 | 
						|
 | 
						|
	// Need and extra header while talking with MS Graph. For more context see
 | 
						|
	// https://docs.microsoft.com/en-us/graph/api/group-list-transitivememberof?view=graph-rest-1.0&tabs=http#request-headers
 | 
						|
	extraHeader := makeAzureHeader(s.AccessToken)
 | 
						|
	extraHeader.Add("ConsistencyLevel", "eventual")
 | 
						|
 | 
						|
	var groups []string
 | 
						|
 | 
						|
	for groupsURL != "" {
 | 
						|
		jsonRequest, err := requests.New(groupsURL).
 | 
						|
			WithContext(ctx).
 | 
						|
			WithHeaders(extraHeader).
 | 
						|
			Do().
 | 
						|
			UnmarshalSimpleJSON()
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("unable to unmarshal Microsoft Graph response: %v", err)
 | 
						|
 | 
						|
		}
 | 
						|
		groupsURL, err = jsonRequest.Get("@odata.nextLink").String()
 | 
						|
		if err != nil {
 | 
						|
			groupsURL = ""
 | 
						|
		}
 | 
						|
		groupsPage := getGroupsFromJSON(jsonRequest, p.GraphGroupField)
 | 
						|
		groups = append(groups, groupsPage...)
 | 
						|
	}
 | 
						|
 | 
						|
	return groups, nil
 | 
						|
}
 | 
						|
 | 
						|
func getGroupsFromJSON(json *simplejson.Json, graphGroupField string) []string {
 | 
						|
	groups := []string{}
 | 
						|
 | 
						|
	for i := range json.Get("value").MustArray() {
 | 
						|
		value := json.Get("value").GetIndex(i).Get(graphGroupField).MustString()
 | 
						|
		groups = append(groups, value)
 | 
						|
	}
 | 
						|
 | 
						|
	return groups
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) getEmailFromProfileAPI(ctx context.Context, accessToken string) (string, error) {
 | 
						|
	if accessToken == "" {
 | 
						|
		return "", fmt.Errorf("missing access token")
 | 
						|
	}
 | 
						|
 | 
						|
	json, err := requests.New(p.ProfileURL.String()).
 | 
						|
		WithContext(ctx).
 | 
						|
		WithHeaders(makeAzureHeader(accessToken)).
 | 
						|
		Do().
 | 
						|
		UnmarshalSimpleJSON()
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	email, err := getEmailFromJSON(json)
 | 
						|
	if email == "" {
 | 
						|
		return "", fmt.Errorf("empty email address: %v", err)
 | 
						|
	}
 | 
						|
	return email, nil
 | 
						|
}
 | 
						|
 | 
						|
func getEmailFromJSON(json *simplejson.Json) (string, error) {
 | 
						|
	email, err := json.Get("mail").String()
 | 
						|
 | 
						|
	if err != nil || email == "" {
 | 
						|
		otherMails, otherMailsErr := json.Get("otherMails").Array()
 | 
						|
		if len(otherMails) > 0 {
 | 
						|
			email = otherMails[0].(string)
 | 
						|
		}
 | 
						|
		err = otherMailsErr
 | 
						|
	}
 | 
						|
 | 
						|
	if err != nil || email == "" {
 | 
						|
		email, err = json.Get("userPrincipalName").String()
 | 
						|
		if err != nil {
 | 
						|
			logger.Errorf("unable to find userPrincipalName: %s", err)
 | 
						|
			return "", err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return email, nil
 | 
						|
}
 | 
						|
 | 
						|
// ValidateSession validates the AccessToken
 | 
						|
func (p *AzureProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
 | 
						|
	return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken))
 | 
						|
}
 |