Github provider
This commit is contained in:
		
							parent
							
								
									8471f972e1
								
							
						
					
					
						commit
						37b38dd2f4
					
				|  | @ -101,6 +101,8 @@ Usage of google_auth_proxy: | ||||||
|   -version=false: print version string |   -version=false: print version string | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
|  | See below for provider specific options | ||||||
|  | 
 | ||||||
| ### Environment variables | ### Environment variables | ||||||
| 
 | 
 | ||||||
| The environment variables `GOOGLE_AUTH_PROXY_CLIENT_ID`, `GOOGLE_AUTH_PROXY_CLIENT_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_DOMAIN` and `GOOGLE_AUTH_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments. | The environment variables `GOOGLE_AUTH_PROXY_CLIENT_ID`, `GOOGLE_AUTH_PROXY_CLIENT_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_DOMAIN` and `GOOGLE_AUTH_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments. | ||||||
|  | @ -173,6 +175,10 @@ directive. Right now this includes: | ||||||
| * `myusa` - The [MyUSA](https://alpha.my.usa.gov) authentication service | * `myusa` - The [MyUSA](https://alpha.my.usa.gov) authentication service | ||||||
|   ([GitHub](https://github.com/18F/myusa)) |   ([GitHub](https://github.com/18F/myusa)) | ||||||
| * `linkedin` - The [LinkedIn](https://developer.linkedin.com/docs/signin-with-linkedin) Sign In service. | * `linkedin` - The [LinkedIn](https://developer.linkedin.com/docs/signin-with-linkedin) Sign In service. | ||||||
|  | * `github` - Via [Github][https://github.com/settings/developers] OAuth App. Also supports restricting via org and team. | ||||||
|  | 
 | ||||||
|  |     -github-org="": restrict logins to members of this organisation | ||||||
|  |     -github-team="": restrict logins to members of this team | ||||||
| 
 | 
 | ||||||
| ## Adding a new Provider | ## Adding a new Provider | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -10,8 +10,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func Request(req *http.Request) (*simplejson.Json, error) { | func Request(req *http.Request) (*simplejson.Json, error) { | ||||||
| 	httpclient := &http.Client{} | 	resp, err := http.DefaultClient.Do(req) | ||||||
| 	resp, err := httpclient.Do(req) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
							
								
								
									
										3
									
								
								main.go
								
								
								
								
							
							
						
						
									
										3
									
								
								main.go
								
								
								
								
							|  | @ -17,6 +17,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func main() { | func main() { | ||||||
|  | 	log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) | ||||||
| 	flagSet := flag.NewFlagSet("google_auth_proxy", flag.ExitOnError) | 	flagSet := flag.NewFlagSet("google_auth_proxy", flag.ExitOnError) | ||||||
| 
 | 
 | ||||||
| 	googleAppsDomains := StringArray{} | 	googleAppsDomains := StringArray{} | ||||||
|  | @ -35,6 +36,8 @@ func main() { | ||||||
| 	flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") | 	flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") | ||||||
| 
 | 
 | ||||||
| 	flagSet.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given Google apps domain (may be given multiple times)") | 	flagSet.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given Google apps domain (may be given multiple times)") | ||||||
|  | 	flagSet.String("github-org", "", "restrict logins to members of this organisation") | ||||||
|  | 	flagSet.String("github-team", "", "restrict logins to members of this team") | ||||||
| 	flagSet.String("client-id", "", "the Google OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") | 	flagSet.String("client-id", "", "the Google OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") | ||||||
| 	flagSet.String("client-secret", "", "the OAuth Client Secret") | 	flagSet.String("client-secret", "", "the OAuth Client Secret") | ||||||
| 	flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") | 	flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") | ||||||
|  |  | ||||||
|  | @ -1,7 +1,6 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" |  | ||||||
| 	"crypto/aes" | 	"crypto/aes" | ||||||
| 	"crypto/cipher" | 	"crypto/cipher" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
|  | @ -17,7 +16,6 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/google_auth_proxy/api" |  | ||||||
| 	"github.com/bitly/google_auth_proxy/providers" | 	"github.com/bitly/google_auth_proxy/providers" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -39,7 +37,6 @@ type OauthProxy struct { | ||||||
| 
 | 
 | ||||||
| 	redirectUrl         *url.URL // the url to receive requests at
 | 	redirectUrl         *url.URL // the url to receive requests at
 | ||||||
| 	provider            providers.Provider | 	provider            providers.Provider | ||||||
| 	oauthRedemptionUrl  *url.URL // endpoint to redeem the code
 |  | ||||||
| 	oauthLoginUrl       *url.URL // to redirect the user to
 | 	oauthLoginUrl       *url.URL // to redirect the user to
 | ||||||
| 	oauthValidateUrl    *url.URL // to validate the access token
 | 	oauthValidateUrl    *url.URL // to validate the access token
 | ||||||
| 	oauthScope          string | 	oauthScope          string | ||||||
|  | @ -147,7 +144,6 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | ||||||
| 		clientSecret:     opts.ClientSecret, | 		clientSecret:     opts.ClientSecret, | ||||||
| 		oauthScope:       opts.provider.Data().Scope, | 		oauthScope:       opts.provider.Data().Scope, | ||||||
| 		provider:         opts.provider, | 		provider:         opts.provider, | ||||||
| 		oauthRedemptionUrl: opts.provider.Data().RedeemUrl, |  | ||||||
| 		oauthLoginUrl:    opts.provider.Data().LoginUrl, | 		oauthLoginUrl:    opts.provider.Data().LoginUrl, | ||||||
| 		oauthValidateUrl: opts.provider.Data().ValidateUrl, | 		oauthValidateUrl: opts.provider.Data().ValidateUrl, | ||||||
| 		serveMux:         serveMux, | 		serveMux:         serveMux, | ||||||
|  | @ -200,29 +196,13 @@ func (p *OauthProxy) redeemCode(host, code string) (string, string, error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return "", "", errors.New("missing code") | 		return "", "", errors.New("missing code") | ||||||
| 	} | 	} | ||||||
| 	params := url.Values{} | 	redirectUri := p.GetRedirectUrl(host) | ||||||
| 	params.Add("redirect_uri", p.GetRedirectUrl(host)) | 	body, access_token, err := p.provider.Redeem(redirectUri, code) | ||||||
| 	params.Add("client_id", p.clientID) |  | ||||||
| 	params.Add("client_secret", p.clientSecret) |  | ||||||
| 	params.Add("code", code) |  | ||||||
| 	params.Add("grant_type", "authorization_code") |  | ||||||
| 	req, err := http.NewRequest("POST", p.oauthRedemptionUrl.String(), bytes.NewBufferString(params.Encode())) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Printf("failed building request %s", err.Error()) |  | ||||||
| 		return "", "", err |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |  | ||||||
| 	json, err := api.Request(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Printf("failed making request %s", err) |  | ||||||
| 		return "", "", err |  | ||||||
| 	} |  | ||||||
| 	access_token, err := json.Get("access_token").String() |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", "", err | 		return "", "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	email, err := p.provider.GetEmailAddress(json, access_token) | 	email, err := p.provider.GetEmailAddress(body, access_token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", "", err | 		return "", "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -1,10 +1,10 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/bitly/google_auth_proxy/providers" | 	"github.com/bitly/google_auth_proxy/providers" | ||||||
| 	"github.com/bmizerany/assert" | 	"github.com/bmizerany/assert" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
|  | 	"log" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | @ -15,6 +15,11 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func init() { | ||||||
|  | 	log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestNewReverseProxy(t *testing.T) { | func TestNewReverseProxy(t *testing.T) { | ||||||
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		w.WriteHeader(200) | 		w.WriteHeader(200) | ||||||
|  | @ -89,8 +94,7 @@ type TestProvider struct { | ||||||
| 	ValidToken   bool | 	ValidToken   bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json, | func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||||
| 	unused_access_token string) (string, error) { |  | ||||||
| 	return tp.EmailAddress, nil | 	return tp.EmailAddress, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -113,16 +117,15 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes | ||||||
| 
 | 
 | ||||||
| 	t.provider_server = httptest.NewServer( | 	t.provider_server = httptest.NewServer( | ||||||
| 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 			log.Printf("%#v", r) | ||||||
| 			url := r.URL | 			url := r.URL | ||||||
| 			payload := "" | 			payload := "" | ||||||
| 			switch url.Path { | 			switch url.Path { | ||||||
| 			case "/oauth/token": | 			case "/oauth/token": | ||||||
| 				payload = `{"access_token": "my_auth_token"}` | 				payload = `{"access_token": "my_auth_token"}` | ||||||
| 			default: | 			default: | ||||||
| 				token_header := r.Header["X-Forwarded-Access-Token"] | 				payload = r.Header.Get("X-Forwarded-Access-Token") | ||||||
| 				if len(token_header) != 0 { | 				if payload == "" { | ||||||
| 					payload = token_header[0] |  | ||||||
| 				} else { |  | ||||||
| 					payload = "No access token found." | 					payload = "No access token found." | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | @ -189,8 +192,7 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, | ||||||
| 	return rw.Code, rw.HeaderMap["Set-Cookie"][0] | 	return rw.Code, rw.HeaderMap["Set-Cookie"][0] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (pat_test *PassAccessTokenTest) getRootEndpoint( | func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) { | ||||||
| 	cookie string) (http_code int, access_token string) { |  | ||||||
| 	cookie_key := pat_test.proxy.CookieKey | 	cookie_key := pat_test.proxy.CookieKey | ||||||
| 	var value string | 	var value string | ||||||
| 	key_prefix := cookie_key + "=" | 	key_prefix := cookie_key + "=" | ||||||
|  |  | ||||||
|  | @ -19,6 +19,8 @@ type Options struct { | ||||||
| 
 | 
 | ||||||
| 	AuthenticatedEmailsFile string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` | 	AuthenticatedEmailsFile string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` | ||||||
| 	GoogleAppsDomains       []string `flag:"google-apps-domain" cfg:"google_apps_domains"` | 	GoogleAppsDomains       []string `flag:"google-apps-domain" cfg:"google_apps_domains"` | ||||||
|  | 	GitHubOrg               string   `flag:"github-org" cfg:"github_org"` | ||||||
|  | 	GitHubTeam              string   `flag:"github-team" cfg:"github_team"` | ||||||
| 	HtpasswdFile            string   `flag:"htpasswd-file" cfg:"htpasswd_file"` | 	HtpasswdFile            string   `flag:"htpasswd-file" cfg:"htpasswd_file"` | ||||||
| 	DisplayHtpasswdForm     bool     `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` | 	DisplayHtpasswdForm     bool     `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` | ||||||
| 	CustomTemplatesDir      string   `flag:"custom-templates-dir" cfg:"custom_templates_dir"` | 	CustomTemplatesDir      string   `flag:"custom-templates-dir" cfg:"custom_templates_dir"` | ||||||
|  | @ -153,11 +155,16 @@ func (o *Options) Validate() error { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func parseProviderInfo(o *Options, msgs []string) []string { | func parseProviderInfo(o *Options, msgs []string) []string { | ||||||
| 	p := &providers.ProviderData{Scope: o.Scope} | 	p := &providers.ProviderData{Scope: o.Scope, ClientID: o.ClientID, ClientSecret: o.ClientSecret} | ||||||
| 	p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs) | 	p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs) | ||||||
| 	p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs) | 	p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs) | ||||||
| 	p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs) | 	p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs) | ||||||
| 	p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs) | 	p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs) | ||||||
|  | 
 | ||||||
| 	o.provider = providers.New(o.Provider, p) | 	o.provider = providers.New(o.Provider, p) | ||||||
|  | 	switch p := o.provider.(type) { | ||||||
|  | 	case *providers.GitHubProvider: | ||||||
|  | 		p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) | ||||||
|  | 	} | ||||||
| 	return msgs | 	return msgs | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,136 @@ | ||||||
|  | package providers | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type GitHubProvider struct { | ||||||
|  | 	*ProviderData | ||||||
|  | 	Org  string | ||||||
|  | 	Team string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewGitHubProvider(p *ProviderData) *GitHubProvider { | ||||||
|  | 	p.ProviderName = "GitHub" | ||||||
|  | 	if p.LoginUrl.String() == "" { | ||||||
|  | 		p.LoginUrl = &url.URL{ | ||||||
|  | 			Scheme: "https", | ||||||
|  | 			Host:   "github.com", | ||||||
|  | 			Path:   "/login/oauth/authorize", | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if p.RedeemUrl.String() == "" { | ||||||
|  | 		p.RedeemUrl = &url.URL{ | ||||||
|  | 			Scheme: "https", | ||||||
|  | 			Host:   "github.com", | ||||||
|  | 			Path:   "/login/oauth/access_token", | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if p.ValidateUrl.String() == "" { | ||||||
|  | 		p.ValidateUrl = &url.URL{ | ||||||
|  | 			Scheme: "https", | ||||||
|  | 			Host:   "api.github.com", | ||||||
|  | 			Path:   "/user/emails", | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if p.Scope == "" { | ||||||
|  | 		p.Scope = "user:email" | ||||||
|  | 	} | ||||||
|  | 	return &GitHubProvider{ProviderData: p} | ||||||
|  | } | ||||||
|  | func (p *GitHubProvider) SetOrgTeam(org, team string) { | ||||||
|  | 	p.Org = org | ||||||
|  | 	p.Team = team | ||||||
|  | 	if org != "" || team != "" { | ||||||
|  | 		p.Scope += " read:org" | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { | ||||||
|  | 
 | ||||||
|  | 	var teams []struct { | ||||||
|  | 		Name string `json:"name"` | ||||||
|  | 		Slug string `json:"slug"` | ||||||
|  | 		Org  struct { | ||||||
|  | 			Login string `json:"login"` | ||||||
|  | 		} `json:"organization"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	params := url.Values{ | ||||||
|  | 		"access_token": {accessToken}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	req, _ := http.NewRequest("GET", "https://api.github.com/user/teams?"+params.Encode(), nil) | ||||||
|  | 	req.Header.Set("Accept", "application/vnd.github.moondragon+json") | ||||||
|  | 	resp, err := http.DefaultClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	body, err := ioutil.ReadAll(resp.Body) | ||||||
|  | 	resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := json.Unmarshal(body, &teams); err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, team := range teams { | ||||||
|  | 		if p.Org == team.Org.Login { | ||||||
|  | 			if p.Team == "" || p.Team == team.Slug { | ||||||
|  | 				return true, nil | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||||
|  | 
 | ||||||
|  | 	var emails []struct { | ||||||
|  | 		Email   string `json:"email"` | ||||||
|  | 		Primary bool   `json:"primary"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	params := url.Values{ | ||||||
|  | 		"access_token": {access_token}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// if we require an Org or Team, check that first
 | ||||||
|  | 	if p.Org != "" || p.Team != "" { | ||||||
|  | 		if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp, err := http.DefaultClient.Get("https://api.github.com/user/emails?" + params.Encode()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	body, err = ioutil.ReadAll(resp.Body) | ||||||
|  | 	resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := json.Unmarshal(body, &emails); err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, email := range emails { | ||||||
|  | 		if email.Primary { | ||||||
|  | 			return email.Email, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "", nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *GitHubProvider) ValidateToken(access_token string) bool { | ||||||
|  | 	return validateToken(p, access_token, nil) | ||||||
|  | } | ||||||
|  | @ -2,10 +2,10 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 |  | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type GoogleProvider struct { | type GoogleProvider struct { | ||||||
|  | @ -35,28 +35,34 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { | ||||||
| 	return &GoogleProvider{ProviderData: p} | 	return &GoogleProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *GoogleProvider) GetEmailAddress(auth_response *simplejson.Json, | func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||||
| 	unused_access_token string) (string, error) { | 	var response struct { | ||||||
| 	idToken, err := auth_response.Get("id_token").String() | 		IdToken string `json:"id_token"` | ||||||
| 	if err != nil { | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := json.Unmarshal(body, &response); err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	// id_token is a base64 encode ID token payload
 | 	// id_token is a base64 encode ID token payload
 | ||||||
| 	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | 	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | ||||||
| 	jwt := strings.Split(idToken, ".") | 	jwt := strings.Split(response.IdToken, ".") | ||||||
| 	b, err := jwtDecodeSegment(jwt[1]) | 	b, err := jwtDecodeSegment(jwt[1]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 	data, err := simplejson.NewJson(b) | 
 | ||||||
|  | 	var email struct { | ||||||
|  | 		Email string `json:"email"` | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(b, &email) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 	email, err := data.Get("email").String() | 	if email.Email == "" { | ||||||
| 	if err != nil { | 		return "", errors.New("missing email") | ||||||
| 		return "", err |  | ||||||
| 	} | 	} | ||||||
| 	return email, nil | 	return email.Email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func jwtDecodeSegment(seg string) ([]byte, error) { | func jwtDecodeSegment(seg string) ([]byte, error) { | ||||||
|  |  | ||||||
|  | @ -2,7 +2,7 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"github.com/bitly/go-simplejson" | 	"encoding/json" | ||||||
| 	"github.com/bmizerany/assert" | 	"github.com/bmizerany/assert" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | @ -68,39 +68,61 @@ func TestGoogleProviderOverrides(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderGetEmailAddress(t *testing.T) { | func TestGoogleProviderGetEmailAddress(t *testing.T) { | ||||||
| 	p := newGoogleProvider() | 	p := newGoogleProvider() | ||||||
| 	j := simplejson.New() | 	body, err := json.Marshal( | ||||||
| 	j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( | 		struct { | ||||||
| 		[]byte("{\"email\": \"michael.bland@gsa.gov\"}"))) | 			IdToken string `json:"id_token"` | ||||||
| 	email, err := p.GetEmailAddress(j, "ignored access_token") | 		}{ | ||||||
|  | 			IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)), | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { | func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { | ||||||
| 	p := newGoogleProvider() | 	p := newGoogleProvider() | ||||||
| 	j := simplejson.New() | 	body, err := json.Marshal( | ||||||
| 	j.Set("id_token", "ignored prefix.{\"email\": \"michael.bland@gsa.gov\"}") | 		struct { | ||||||
| 	email, err := p.GetEmailAddress(j, "ignored access_token") | 			IdToken string `json:"id_token"` | ||||||
|  | 		}{ | ||||||
|  | 			IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { | func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { | ||||||
| 	p := newGoogleProvider() | 	p := newGoogleProvider() | ||||||
| 	j := simplejson.New() | 
 | ||||||
| 	j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( | 	body, err := json.Marshal( | ||||||
| 		[]byte("{email: michael.bland@gsa.gov}"))) | 		struct { | ||||||
| 	email, err := p.GetEmailAddress(j, "ignored access_token") | 			IdToken string `json:"id_token"` | ||||||
|  | 		}{ | ||||||
|  | 			IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { | func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { | ||||||
| 	p := newGoogleProvider() | 	p := newGoogleProvider() | ||||||
| 	j := simplejson.New() | 	body, err := json.Marshal( | ||||||
| 	j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString( | 		struct { | ||||||
| 		[]byte("{\"not_email\": \"missing!\"}"))) | 			IdToken string `json:"id_token"` | ||||||
| 	email, err := p.GetEmailAddress(j, "ignored access_token") | 		}{ | ||||||
|  | 			IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,7 +1,6 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/bmizerany/assert" | 	"github.com/bmizerany/assert" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | @ -13,9 +12,7 @@ type ValidateTokenTestProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tp *ValidateTokenTestProvider) GetEmailAddress( | func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||||
| 	unused_auth_response *simplejson.Json, |  | ||||||
| 	unused_access_token string) (string, error) { |  | ||||||
| 	return "", nil | 	return "", nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -8,7 +8,6 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/bitly/google_auth_proxy/api" | 	"github.com/bitly/google_auth_proxy/api" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -50,8 +49,7 @@ func getLinkedInHeader(access_token string) http.Header { | ||||||
| 	return header | 	return header | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json, | func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||||
| 	access_token string) (string, error) { |  | ||||||
| 	if access_token == "" { | 	if access_token == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -1,7 +1,6 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/bmizerany/assert" | 	"github.com/bmizerany/assert" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | @ -97,9 +96,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	b_url, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(b_url.Host) | 	p := testLinkedInProvider(b_url.Host) | ||||||
| 	unused_auth_response := simplejson.New() |  | ||||||
| 
 | 
 | ||||||
| 	email, err := p.GetEmailAddress(unused_auth_response, | 	email, err := p.GetEmailAddress([]byte{}, | ||||||
| 		"imaginary_access_token") | 		"imaginary_access_token") | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "user@linkedin.com", email) | 	assert.Equal(t, "user@linkedin.com", email) | ||||||
|  | @ -111,13 +109,11 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	b_url, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(b_url.Host) | 	p := testLinkedInProvider(b_url.Host) | ||||||
| 	unused_auth_response := simplejson.New() |  | ||||||
| 
 | 
 | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 | 	// We'll trigger a request failure by using an unexpected access
 | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||||
| 	// JSON to fail.
 | 	// JSON to fail.
 | ||||||
| 	email, err := p.GetEmailAddress(unused_auth_response, | 	email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") | ||||||
| 		"unexpected_access_token") |  | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
| } | } | ||||||
|  | @ -128,10 +124,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	b_url, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(b_url.Host) | 	p := testLinkedInProvider(b_url.Host) | ||||||
| 	unused_auth_response := simplejson.New() |  | ||||||
| 
 | 
 | ||||||
| 	email, err := p.GetEmailAddress(unused_auth_response, | 	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") | ||||||
| 		"imaginary_access_token") |  | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -5,7 +5,6 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/bitly/google_auth_proxy/api" | 	"github.com/bitly/google_auth_proxy/api" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -43,8 +42,7 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider { | ||||||
| 	return &MyUsaProvider{ProviderData: p} | 	return &MyUsaProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *MyUsaProvider) GetEmailAddress(auth_response *simplejson.Json, | func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||||
| 	access_token string) (string, error) { |  | ||||||
| 	req, err := http.NewRequest("GET", | 	req, err := http.NewRequest("GET", | ||||||
| 		p.ProfileUrl.String()+"?access_token="+access_token, nil) | 		p.ProfileUrl.String()+"?access_token="+access_token, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -1,7 +1,6 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/bmizerany/assert" | 	"github.com/bmizerany/assert" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | @ -102,10 +101,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	b_url, _ := url.Parse(b.URL) | ||||||
| 	p := testMyUsaProvider(b_url.Host) | 	p := testMyUsaProvider(b_url.Host) | ||||||
| 	unused_auth_response := simplejson.New() |  | ||||||
| 
 | 
 | ||||||
| 	email, err := p.GetEmailAddress(unused_auth_response, | 	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") | ||||||
| 		"imaginary_access_token") |  | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
| } | } | ||||||
|  | @ -118,13 +115,11 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	b_url, _ := url.Parse(b.URL) | ||||||
| 	p := testMyUsaProvider(b_url.Host) | 	p := testMyUsaProvider(b_url.Host) | ||||||
| 	unused_auth_response := simplejson.New() |  | ||||||
| 
 | 
 | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 | 	// We'll trigger a request failure by using an unexpected access
 | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||||
| 	// JSON to fail.
 | 	// JSON to fail.
 | ||||||
| 	email, err := p.GetEmailAddress(unused_auth_response, | 	email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") | ||||||
| 		"unexpected_access_token") |  | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
| } | } | ||||||
|  | @ -135,10 +130,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	b_url, _ := url.Parse(b.URL) | ||||||
| 	p := testMyUsaProvider(b_url.Host) | 	p := testMyUsaProvider(b_url.Host) | ||||||
| 	unused_auth_response := simplejson.New() |  | ||||||
| 
 | 
 | ||||||
| 	email, err := p.GetEmailAddress(unused_auth_response, | 	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") | ||||||
| 		"imaginary_access_token") |  | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -6,6 +6,8 @@ import ( | ||||||
| 
 | 
 | ||||||
| type ProviderData struct { | type ProviderData struct { | ||||||
| 	ProviderName string | 	ProviderName string | ||||||
|  | 	ClientID     string | ||||||
|  | 	ClientSecret string | ||||||
| 	LoginUrl     *url.URL | 	LoginUrl     *url.URL | ||||||
| 	RedeemUrl    *url.URL | 	RedeemUrl    *url.URL | ||||||
| 	ProfileUrl   *url.URL | 	ProfileUrl   *url.URL | ||||||
|  |  | ||||||
|  | @ -0,0 +1,51 @@ | ||||||
|  | package providers | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) { | ||||||
|  | 	if code == "" { | ||||||
|  | 		err = errors.New("missing code") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	params := url.Values{} | ||||||
|  | 	params.Add("redirect_uri", redirectUrl) | ||||||
|  | 	params.Add("client_id", p.ClientID) | ||||||
|  | 	params.Add("client_secret", p.ClientSecret) | ||||||
|  | 	params.Add("code", code) | ||||||
|  | 	params.Add("grant_type", "authorization_code") | ||||||
|  | 	req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, "", err | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||||
|  | 
 | ||||||
|  | 	resp, err := http.DefaultClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, "", err | ||||||
|  | 	} | ||||||
|  | 	body, err = ioutil.ReadAll(resp.Body) | ||||||
|  | 	resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, "", err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// blindly try json and x-www-form-urlencoded
 | ||||||
|  | 	var jsonResponse struct { | ||||||
|  | 		AccessToken string `json:"access_token"` | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(body, &jsonResponse) | ||||||
|  | 	if err == nil { | ||||||
|  | 		return body, jsonResponse.AccessToken, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	v, err := url.ParseQuery(string(body)) | ||||||
|  | 	return body, v.Get("access_token"), err | ||||||
|  | } | ||||||
|  | @ -1,13 +1,9 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( |  | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Provider interface { | type Provider interface { | ||||||
| 	Data() *ProviderData | 	Data() *ProviderData | ||||||
| 	GetEmailAddress(auth_response *simplejson.Json, | 	GetEmailAddress(body []byte, access_token string) (string, error) | ||||||
| 		access_token string) (string, error) | 	Redeem(string, string) ([]byte, string, error) | ||||||
| 	ValidateToken(access_token string) bool | 	ValidateToken(access_token string) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -17,6 +13,8 @@ func New(provider string, p *ProviderData) Provider { | ||||||
| 		return NewMyUsaProvider(p) | 		return NewMyUsaProvider(p) | ||||||
| 	case "linkedin": | 	case "linkedin": | ||||||
| 		return NewLinkedInProvider(p) | 		return NewLinkedInProvider(p) | ||||||
|  | 	case "github": | ||||||
|  | 		return NewGitHubProvider(p) | ||||||
| 	default: | 	default: | ||||||
| 		return NewGoogleProvider(p) | 		return NewGoogleProvider(p) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue