286 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			286 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
package providers
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/bitly/go-simplejson"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
 | 
						|
)
 | 
						|
 | 
						|
// AzureProvider represents an Azure based Identity Provider
 | 
						|
type AzureProvider struct {
 | 
						|
	*ProviderData
 | 
						|
	Tenant string
 | 
						|
}
 | 
						|
 | 
						|
var _ Provider = (*AzureProvider)(nil)
 | 
						|
 | 
						|
// Fixed identity and scope values used when building the Azure provider.
const (
	// azureDefaultScope is the OAuth scope requested when none is configured.
	azureDefaultScope = "openid"

	// azureProviderName is the display name reported for this provider.
	azureProviderName = "Azure"
)
// Default Azure endpoints, written as URL literals so no parsing happens at
// start-up.
//
// NewAzureProvider may store these exact pointers on a ProviderData (it
// assigns azureDefaultProtectResourceURL directly), so treat them as
// read-only: writing through one would change the default for every provider.
var (
	// azureDefaultLoginURL is the authorization endpoint of the default
	// "common" tenant: https://login.microsoftonline.com/common/oauth2/authorize.
	azureDefaultLoginURL = &url.URL{
		Scheme: "https",
		Host:   "login.microsoftonline.com",
		Path:   "/common/oauth2/authorize",
	}

	// azureDefaultRedeemURL is the token endpoint of the default "common"
	// tenant: https://login.microsoftonline.com/common/oauth2/token.
	azureDefaultRedeemURL = &url.URL{
		Scheme: "https",
		Host:   "login.microsoftonline.com",
		Path:   "/common/oauth2/token",
	}

	// azureDefaultProfileURL is the Microsoft Graph "current user" resource:
	// https://graph.microsoft.com/v1.0/me.
	azureDefaultProfileURL = &url.URL{
		Scheme: "https",
		Host:   "graph.microsoft.com",
		Path:   "/v1.0/me",
	}

	// azureDefaultProtectResourceURL is the resource requested for tokens
	// when no ProtectedResource is configured: https://graph.microsoft.com.
	azureDefaultProtectResourceURL = &url.URL{
		Scheme: "https",
		Host:   "graph.microsoft.com",
	}
)
// NewAzureProvider initiates a new AzureProvider
 | 
						|
func NewAzureProvider(p *ProviderData) *AzureProvider {
 | 
						|
	p.setProviderDefaults(providerDefaults{
 | 
						|
		name:        azureProviderName,
 | 
						|
		loginURL:    azureDefaultLoginURL,
 | 
						|
		redeemURL:   azureDefaultRedeemURL,
 | 
						|
		profileURL:  azureDefaultProfileURL,
 | 
						|
		validateURL: nil,
 | 
						|
		scope:       azureDefaultScope,
 | 
						|
	})
 | 
						|
 | 
						|
	if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
 | 
						|
		p.ProtectedResource = azureDefaultProtectResourceURL
 | 
						|
	}
 | 
						|
	if p.ValidateURL == nil || p.ValidateURL.String() == "" {
 | 
						|
		p.ValidateURL = p.ProfileURL
 | 
						|
	}
 | 
						|
 | 
						|
	return &AzureProvider{
 | 
						|
		ProviderData: p,
 | 
						|
		Tenant:       "common",
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Configure defaults the AzureProvider configuration options
 | 
						|
func (p *AzureProvider) Configure(tenant string) {
 | 
						|
	if tenant == "" || tenant == "common" {
 | 
						|
		// tenant is empty or default, remain on the default "common" tenant
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	// Specific tennant specified, override the Login and RedeemURLs
 | 
						|
	p.Tenant = tenant
 | 
						|
	overrideTenantURL(p.LoginURL, azureDefaultLoginURL, tenant, "authorize")
 | 
						|
	overrideTenantURL(p.RedeemURL, azureDefaultRedeemURL, tenant, "token")
 | 
						|
}
 | 
						|
 | 
						|
func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
 | 
						|
	if current == nil || current.String() == "" || current.String() == defaultURL.String() {
 | 
						|
		*current = url.URL{
 | 
						|
			Scheme: "https",
 | 
						|
			Host:   "login.microsoftonline.com",
 | 
						|
			Path:   "/" + tenant + "/oauth2/" + path}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Redeem exchanges the OAuth2 authentication token for an ID token
 | 
						|
func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
 | 
						|
	if code == "" {
 | 
						|
		return nil, ErrMissingCode
 | 
						|
	}
 | 
						|
	clientSecret, err := p.GetClientSecret()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	params := url.Values{}
 | 
						|
	params.Add("redirect_uri", redirectURL)
 | 
						|
	params.Add("client_id", p.ClientID)
 | 
						|
	params.Add("client_secret", clientSecret)
 | 
						|
	params.Add("code", code)
 | 
						|
	params.Add("grant_type", "authorization_code")
 | 
						|
	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
 | 
						|
		params.Add("resource", p.ProtectedResource.String())
 | 
						|
	}
 | 
						|
 | 
						|
	// blindly try json and x-www-form-urlencoded
 | 
						|
	var jsonResponse struct {
 | 
						|
		AccessToken  string `json:"access_token"`
 | 
						|
		RefreshToken string `json:"refresh_token"`
 | 
						|
		ExpiresOn    int64  `json:"expires_on,string"`
 | 
						|
		IDToken      string `json:"id_token"`
 | 
						|
	}
 | 
						|
 | 
						|
	err = requests.New(p.RedeemURL.String()).
 | 
						|
		WithContext(ctx).
 | 
						|
		WithMethod("POST").
 | 
						|
		WithBody(bytes.NewBufferString(params.Encode())).
 | 
						|
		SetHeader("Content-Type", "application/x-www-form-urlencoded").
 | 
						|
		Do().
 | 
						|
		UnmarshalInto(&jsonResponse)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	created := time.Now()
 | 
						|
	expires := time.Unix(jsonResponse.ExpiresOn, 0)
 | 
						|
 | 
						|
	return &sessions.SessionState{
 | 
						|
		AccessToken:  jsonResponse.AccessToken,
 | 
						|
		IDToken:      jsonResponse.IDToken,
 | 
						|
		CreatedAt:    &created,
 | 
						|
		ExpiresOn:    &expires,
 | 
						|
		RefreshToken: jsonResponse.RefreshToken,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
// RefreshSessionIfNeeded checks if the session has expired and uses the
 | 
						|
// RefreshToken to fetch a new ID token if required
 | 
						|
func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
 | 
						|
	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
 | 
						|
		return false, nil
 | 
						|
	}
 | 
						|
 | 
						|
	origExpiration := s.ExpiresOn
 | 
						|
 | 
						|
	err := p.redeemRefreshToken(ctx, s)
 | 
						|
	if err != nil {
 | 
						|
		return false, fmt.Errorf("unable to redeem refresh token: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
 | 
						|
	return true, nil
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
 | 
						|
	params := url.Values{}
 | 
						|
	params.Add("client_id", p.ClientID)
 | 
						|
	params.Add("client_secret", p.ClientSecret)
 | 
						|
	params.Add("refresh_token", s.RefreshToken)
 | 
						|
	params.Add("grant_type", "refresh_token")
 | 
						|
 | 
						|
	var jsonResponse struct {
 | 
						|
		AccessToken  string `json:"access_token"`
 | 
						|
		RefreshToken string `json:"refresh_token"`
 | 
						|
		ExpiresOn    int64  `json:"expires_on,string"`
 | 
						|
		IDToken      string `json:"id_token"`
 | 
						|
	}
 | 
						|
 | 
						|
	err = requests.New(p.RedeemURL.String()).
 | 
						|
		WithContext(ctx).
 | 
						|
		WithMethod("POST").
 | 
						|
		WithBody(bytes.NewBufferString(params.Encode())).
 | 
						|
		SetHeader("Content-Type", "application/x-www-form-urlencoded").
 | 
						|
		Do().
 | 
						|
		UnmarshalInto(&jsonResponse)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	now := time.Now()
 | 
						|
	expires := time.Unix(jsonResponse.ExpiresOn, 0)
 | 
						|
	s.AccessToken = jsonResponse.AccessToken
 | 
						|
	s.IDToken = jsonResponse.IDToken
 | 
						|
	s.RefreshToken = jsonResponse.RefreshToken
 | 
						|
	s.CreatedAt = &now
 | 
						|
	s.ExpiresOn = &expires
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func makeAzureHeader(accessToken string) http.Header {
 | 
						|
	return makeAuthorizationHeader(tokenTypeBearer, accessToken, nil)
 | 
						|
}
 | 
						|
 | 
						|
func getEmailFromJSON(json *simplejson.Json) (string, error) {
 | 
						|
	var email string
 | 
						|
	var err error
 | 
						|
 | 
						|
	email, err = json.Get("mail").String()
 | 
						|
 | 
						|
	if err != nil || email == "" {
 | 
						|
		otherMails, otherMailsErr := json.Get("otherMails").Array()
 | 
						|
		if len(otherMails) > 0 {
 | 
						|
			email = otherMails[0].(string)
 | 
						|
		}
 | 
						|
		err = otherMailsErr
 | 
						|
	}
 | 
						|
 | 
						|
	return email, err
 | 
						|
}
 | 
						|
 | 
						|
// GetEmailAddress returns the Account email address
 | 
						|
func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
 | 
						|
	var email string
 | 
						|
	var err error
 | 
						|
 | 
						|
	if s.AccessToken == "" {
 | 
						|
		return "", errors.New("missing access token")
 | 
						|
	}
 | 
						|
 | 
						|
	json, err := requests.New(p.ProfileURL.String()).
 | 
						|
		WithContext(ctx).
 | 
						|
		WithHeaders(makeAzureHeader(s.AccessToken)).
 | 
						|
		Do().
 | 
						|
		UnmarshalJSON()
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	email, err = getEmailFromJSON(json)
 | 
						|
	if err == nil && email != "" {
 | 
						|
		return email, err
 | 
						|
	}
 | 
						|
 | 
						|
	email, err = json.Get("userPrincipalName").String()
 | 
						|
	if err != nil {
 | 
						|
		logger.Errorf("failed making request %s", err)
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	if email == "" {
 | 
						|
		logger.Errorf("failed to get email address")
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	return email, err
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) GetLoginURL(redirectURI, state string) string {
 | 
						|
	extraParams := url.Values{}
 | 
						|
	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
 | 
						|
		extraParams.Add("resource", p.ProtectedResource.String())
 | 
						|
	}
 | 
						|
	a := makeLoginURL(p.ProviderData, redirectURI, state, extraParams)
 | 
						|
	return a.String()
 | 
						|
}
 | 
						|
 | 
						|
// ValidateSession validates the AccessToken
 | 
						|
func (p *AzureProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
 | 
						|
	return validateToken(ctx, p, s.AccessToken, makeAzureHeader(s.AccessToken))
 | 
						|
}
 |