196 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			196 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
package providers
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/bitly/go-simplejson"
 | 
						|
	"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
 | 
						|
	"github.com/pusher/oauth2_proxy/pkg/logger"
 | 
						|
	"github.com/pusher/oauth2_proxy/pkg/requests"
 | 
						|
)
 | 
						|
 | 
						|
// AzureProvider represents an Azure based Identity Provider
 | 
						|
type AzureProvider struct {
 | 
						|
	*ProviderData
 | 
						|
	Tenant string
 | 
						|
}
 | 
						|
 | 
						|
// NewAzureProvider initiates a new AzureProvider
 | 
						|
func NewAzureProvider(p *ProviderData) *AzureProvider {
 | 
						|
	p.ProviderName = "Azure"
 | 
						|
 | 
						|
	if p.ProfileURL == nil || p.ProfileURL.String() == "" {
 | 
						|
		p.ProfileURL = &url.URL{
 | 
						|
			Scheme:   "https",
 | 
						|
			Host:     "graph.windows.net",
 | 
						|
			Path:     "/me",
 | 
						|
			RawQuery: "api-version=1.6",
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
 | 
						|
		p.ProtectedResource = &url.URL{
 | 
						|
			Scheme: "https",
 | 
						|
			Host:   "graph.windows.net",
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if p.Scope == "" {
 | 
						|
		p.Scope = "openid"
 | 
						|
	}
 | 
						|
 | 
						|
	return &AzureProvider{ProviderData: p}
 | 
						|
}
 | 
						|
 | 
						|
// Configure defaults the AzureProvider configuration options
 | 
						|
func (p *AzureProvider) Configure(tenant string) {
 | 
						|
	p.Tenant = tenant
 | 
						|
	if tenant == "" {
 | 
						|
		p.Tenant = "common"
 | 
						|
	}
 | 
						|
 | 
						|
	if p.LoginURL == nil || p.LoginURL.String() == "" {
 | 
						|
		p.LoginURL = &url.URL{
 | 
						|
			Scheme: "https",
 | 
						|
			Host:   "login.microsoftonline.com",
 | 
						|
			Path:   "/" + p.Tenant + "/oauth2/authorize"}
 | 
						|
	}
 | 
						|
	if p.RedeemURL == nil || p.RedeemURL.String() == "" {
 | 
						|
		p.RedeemURL = &url.URL{
 | 
						|
			Scheme: "https",
 | 
						|
			Host:   "login.microsoftonline.com",
 | 
						|
			Path:   "/" + p.Tenant + "/oauth2/token",
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (p *AzureProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
 | 
						|
	if code == "" {
 | 
						|
		err = errors.New("missing code")
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	params := url.Values{}
 | 
						|
	params.Add("redirect_uri", redirectURL)
 | 
						|
	params.Add("client_id", p.ClientID)
 | 
						|
	params.Add("client_secret", p.ClientSecret)
 | 
						|
	params.Add("code", code)
 | 
						|
	params.Add("grant_type", "authorization_code")
 | 
						|
	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
 | 
						|
		params.Add("resource", p.ProtectedResource.String())
 | 
						|
	}
 | 
						|
 | 
						|
	var req *http.Request
 | 
						|
	req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
						|
 | 
						|
	var resp *http.Response
 | 
						|
	resp, err = http.DefaultClient.Do(req)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	var body []byte
 | 
						|
	body, err = ioutil.ReadAll(resp.Body)
 | 
						|
	resp.Body.Close()
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	if resp.StatusCode != 200 {
 | 
						|
		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	var jsonResponse struct {
 | 
						|
		AccessToken  string `json:"access_token"`
 | 
						|
		RefreshToken string `json:"refresh_token"`
 | 
						|
		ExpiresOn    int64  `json:"expires_on,string"`
 | 
						|
		IDToken      string `json:"id_token"`
 | 
						|
	}
 | 
						|
	err = json.Unmarshal(body, &jsonResponse)
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	s = &sessions.SessionState{
 | 
						|
		AccessToken:  jsonResponse.AccessToken,
 | 
						|
		IDToken:      jsonResponse.IDToken,
 | 
						|
		CreatedAt:    time.Now(),
 | 
						|
		ExpiresOn:    time.Unix(jsonResponse.ExpiresOn, 0),
 | 
						|
		RefreshToken: jsonResponse.RefreshToken,
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// getAzureHeader returns a fresh header set containing only a bearer
// Authorization header for accessToken.
func getAzureHeader(accessToken string) http.Header {
	// "Authorization" is already in canonical MIME form, so a literal key
	// is equivalent to Header.Set.
	return http.Header{
		"Authorization": {fmt.Sprintf("Bearer %s", accessToken)},
	}
}
 | 
						|
 | 
						|
func getEmailFromJSON(json *simplejson.Json) (string, error) {
 | 
						|
	var email string
 | 
						|
	var err error
 | 
						|
 | 
						|
	email, err = json.Get("mail").String()
 | 
						|
 | 
						|
	if err != nil || email == "" {
 | 
						|
		otherMails, otherMailsErr := json.Get("otherMails").Array()
 | 
						|
		if len(otherMails) > 0 {
 | 
						|
			email = otherMails[0].(string)
 | 
						|
		}
 | 
						|
		err = otherMailsErr
 | 
						|
	}
 | 
						|
 | 
						|
	return email, err
 | 
						|
}
 | 
						|
 | 
						|
// GetEmailAddress returns the Account email address
 | 
						|
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
 | 
						|
	var email string
 | 
						|
	var err error
 | 
						|
 | 
						|
	if s.AccessToken == "" {
 | 
						|
		return "", errors.New("missing access token")
 | 
						|
	}
 | 
						|
	req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	req.Header = getAzureHeader(s.AccessToken)
 | 
						|
 | 
						|
	json, err := requests.Request(req)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	email, err = getEmailFromJSON(json)
 | 
						|
 | 
						|
	if err == nil && email != "" {
 | 
						|
		return email, err
 | 
						|
	}
 | 
						|
 | 
						|
	email, err = json.Get("userPrincipalName").String()
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		logger.Printf("failed making request %s", err)
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	if email == "" {
 | 
						|
		logger.Printf("failed to get email address")
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	return email, err
 | 
						|
}
 |