Merge pull request #98 from jehiah/github_provider_98
Add Github Provider
This commit is contained in:
		
						commit
						4de133a016
					
				| 
						 | 
					@ -101,6 +101,8 @@ Usage of google_auth_proxy:
 | 
				
			||||||
  -version=false: print version string
 | 
					  -version=false: print version string
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					See below for provider specific options
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Environment variables
 | 
					### Environment variables
 | 
				
			||||||
 | 
					
 | 
				
			||||||
The environment variables `GOOGLE_AUTH_PROXY_CLIENT_ID`, `GOOGLE_AUTH_PROXY_CLIENT_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_DOMAIN` and `GOOGLE_AUTH_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments.
 | 
					The environment variables `GOOGLE_AUTH_PROXY_CLIENT_ID`, `GOOGLE_AUTH_PROXY_CLIENT_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_SECRET`, `GOOGLE_AUTH_PROXY_COOKIE_DOMAIN` and `GOOGLE_AUTH_PROXY_COOKIE_EXPIRE` can be used in place of the corresponding command-line arguments.
 | 
				
			||||||
| 
						 | 
					@ -173,6 +175,10 @@ directive. Right now this includes:
 | 
				
			||||||
* `myusa` - The [MyUSA](https://alpha.my.usa.gov) authentication service
 | 
					* `myusa` - The [MyUSA](https://alpha.my.usa.gov) authentication service
 | 
				
			||||||
  ([GitHub](https://github.com/18F/myusa))
 | 
					  ([GitHub](https://github.com/18F/myusa))
 | 
				
			||||||
* `linkedin` - The [LinkedIn](https://developer.linkedin.com/docs/signin-with-linkedin) Sign In service.
 | 
					* `linkedin` - The [LinkedIn](https://developer.linkedin.com/docs/signin-with-linkedin) Sign In service.
 | 
				
			||||||
 | 
					* `github` - Via [Github][https://github.com/settings/developers] OAuth App. Also supports restricting via org and team.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    -github-org="": restrict logins to members of this organisation
 | 
				
			||||||
 | 
					    -github-team="": restrict logins to members of this team
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Adding a new Provider
 | 
					## Adding a new Provider
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,8 +10,7 @@ import (
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Request(req *http.Request) (*simplejson.Json, error) {
 | 
					func Request(req *http.Request) (*simplejson.Json, error) {
 | 
				
			||||||
	httpclient := &http.Client{}
 | 
						resp, err := http.DefaultClient.Do(req)
 | 
				
			||||||
	resp, err := httpclient.Do(req)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										3
									
								
								main.go
								
								
								
								
							
							
						
						
									
										3
									
								
								main.go
								
								
								
								
							| 
						 | 
					@ -17,6 +17,7 @@ import (
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func main() {
 | 
					func main() {
 | 
				
			||||||
 | 
						log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
 | 
				
			||||||
	flagSet := flag.NewFlagSet("google_auth_proxy", flag.ExitOnError)
 | 
						flagSet := flag.NewFlagSet("google_auth_proxy", flag.ExitOnError)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	googleAppsDomains := StringArray{}
 | 
						googleAppsDomains := StringArray{}
 | 
				
			||||||
| 
						 | 
					@ -35,6 +36,8 @@ func main() {
 | 
				
			||||||
	flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
 | 
						flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	flagSet.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given Google apps domain (may be given multiple times)")
 | 
						flagSet.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given Google apps domain (may be given multiple times)")
 | 
				
			||||||
 | 
						flagSet.String("github-org", "", "restrict logins to members of this organisation")
 | 
				
			||||||
 | 
						flagSet.String("github-team", "", "restrict logins to members of this team")
 | 
				
			||||||
	flagSet.String("client-id", "", "the Google OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"")
 | 
						flagSet.String("client-id", "", "the Google OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"")
 | 
				
			||||||
	flagSet.String("client-secret", "", "the OAuth Client Secret")
 | 
						flagSet.String("client-secret", "", "the OAuth Client Secret")
 | 
				
			||||||
	flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)")
 | 
						flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,6 @@
 | 
				
			||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
					 | 
				
			||||||
	"crypto/aes"
 | 
						"crypto/aes"
 | 
				
			||||||
	"crypto/cipher"
 | 
						"crypto/cipher"
 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
| 
						 | 
					@ -17,7 +16,6 @@ import (
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/bitly/google_auth_proxy/api"
 | 
					 | 
				
			||||||
	"github.com/bitly/google_auth_proxy/providers"
 | 
						"github.com/bitly/google_auth_proxy/providers"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,7 +37,6 @@ type OauthProxy struct {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	redirectUrl         *url.URL // the url to receive requests at
 | 
						redirectUrl         *url.URL // the url to receive requests at
 | 
				
			||||||
	provider            providers.Provider
 | 
						provider            providers.Provider
 | 
				
			||||||
	oauthRedemptionUrl  *url.URL // endpoint to redeem the code
 | 
					 | 
				
			||||||
	oauthLoginUrl       *url.URL // to redirect the user to
 | 
						oauthLoginUrl       *url.URL // to redirect the user to
 | 
				
			||||||
	oauthValidateUrl    *url.URL // to validate the access token
 | 
						oauthValidateUrl    *url.URL // to validate the access token
 | 
				
			||||||
	oauthScope          string
 | 
						oauthScope          string
 | 
				
			||||||
| 
						 | 
					@ -147,7 +144,6 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
 | 
				
			||||||
		clientSecret:     opts.ClientSecret,
 | 
							clientSecret:     opts.ClientSecret,
 | 
				
			||||||
		oauthScope:       opts.provider.Data().Scope,
 | 
							oauthScope:       opts.provider.Data().Scope,
 | 
				
			||||||
		provider:         opts.provider,
 | 
							provider:         opts.provider,
 | 
				
			||||||
		oauthRedemptionUrl: opts.provider.Data().RedeemUrl,
 | 
					 | 
				
			||||||
		oauthLoginUrl:    opts.provider.Data().LoginUrl,
 | 
							oauthLoginUrl:    opts.provider.Data().LoginUrl,
 | 
				
			||||||
		oauthValidateUrl: opts.provider.Data().ValidateUrl,
 | 
							oauthValidateUrl: opts.provider.Data().ValidateUrl,
 | 
				
			||||||
		serveMux:         serveMux,
 | 
							serveMux:         serveMux,
 | 
				
			||||||
| 
						 | 
					@ -200,29 +196,13 @@ func (p *OauthProxy) redeemCode(host, code string) (string, string, error) {
 | 
				
			||||||
	if code == "" {
 | 
						if code == "" {
 | 
				
			||||||
		return "", "", errors.New("missing code")
 | 
							return "", "", errors.New("missing code")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	params := url.Values{}
 | 
						redirectUri := p.GetRedirectUrl(host)
 | 
				
			||||||
	params.Add("redirect_uri", p.GetRedirectUrl(host))
 | 
						body, access_token, err := p.provider.Redeem(redirectUri, code)
 | 
				
			||||||
	params.Add("client_id", p.clientID)
 | 
					 | 
				
			||||||
	params.Add("client_secret", p.clientSecret)
 | 
					 | 
				
			||||||
	params.Add("code", code)
 | 
					 | 
				
			||||||
	params.Add("grant_type", "authorization_code")
 | 
					 | 
				
			||||||
	req, err := http.NewRequest("POST", p.oauthRedemptionUrl.String(), bytes.NewBufferString(params.Encode()))
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Printf("failed building request %s", err.Error())
 | 
					 | 
				
			||||||
		return "", "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
					 | 
				
			||||||
	json, err := api.Request(req)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Printf("failed making request %s", err)
 | 
					 | 
				
			||||||
		return "", "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	access_token, err := json.Get("access_token").String()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", "", err
 | 
							return "", "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.provider.GetEmailAddress(json, access_token)
 | 
						email, err := p.provider.GetEmailAddress(body, access_token)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", "", err
 | 
							return "", "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,10 +1,10 @@
 | 
				
			||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
	"github.com/bitly/google_auth_proxy/providers"
 | 
						"github.com/bitly/google_auth_proxy/providers"
 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
| 
						 | 
					@ -15,6 +15,11 @@ import (
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNewReverseProxy(t *testing.T) {
 | 
					func TestNewReverseProxy(t *testing.T) {
 | 
				
			||||||
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
		w.WriteHeader(200)
 | 
							w.WriteHeader(200)
 | 
				
			||||||
| 
						 | 
					@ -89,8 +94,7 @@ type TestProvider struct {
 | 
				
			||||||
	ValidToken   bool
 | 
						ValidToken   bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
 | 
					func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
				
			||||||
	unused_access_token string) (string, error) {
 | 
					 | 
				
			||||||
	return tp.EmailAddress, nil
 | 
						return tp.EmailAddress, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -113,16 +117,15 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.provider_server = httptest.NewServer(
 | 
						t.provider_server = httptest.NewServer(
 | 
				
			||||||
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
							http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
								log.Printf("%#v", r)
 | 
				
			||||||
			url := r.URL
 | 
								url := r.URL
 | 
				
			||||||
			payload := ""
 | 
								payload := ""
 | 
				
			||||||
			switch url.Path {
 | 
								switch url.Path {
 | 
				
			||||||
			case "/oauth/token":
 | 
								case "/oauth/token":
 | 
				
			||||||
				payload = `{"access_token": "my_auth_token"}`
 | 
									payload = `{"access_token": "my_auth_token"}`
 | 
				
			||||||
			default:
 | 
								default:
 | 
				
			||||||
				token_header := r.Header["X-Forwarded-Access-Token"]
 | 
									payload = r.Header.Get("X-Forwarded-Access-Token")
 | 
				
			||||||
				if len(token_header) != 0 {
 | 
									if payload == "" {
 | 
				
			||||||
					payload = token_header[0]
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					payload = "No access token found."
 | 
										payload = "No access token found."
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -189,8 +192,7 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
 | 
				
			||||||
	return rw.Code, rw.HeaderMap["Set-Cookie"][0]
 | 
						return rw.Code, rw.HeaderMap["Set-Cookie"][0]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (pat_test *PassAccessTokenTest) getRootEndpoint(
 | 
					func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
 | 
				
			||||||
	cookie string) (http_code int, access_token string) {
 | 
					 | 
				
			||||||
	cookie_key := pat_test.proxy.CookieKey
 | 
						cookie_key := pat_test.proxy.CookieKey
 | 
				
			||||||
	var value string
 | 
						var value string
 | 
				
			||||||
	key_prefix := cookie_key + "="
 | 
						key_prefix := cookie_key + "="
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,6 +19,8 @@ type Options struct {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	AuthenticatedEmailsFile string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
 | 
						AuthenticatedEmailsFile string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
 | 
				
			||||||
	GoogleAppsDomains       []string `flag:"google-apps-domain" cfg:"google_apps_domains"`
 | 
						GoogleAppsDomains       []string `flag:"google-apps-domain" cfg:"google_apps_domains"`
 | 
				
			||||||
 | 
						GitHubOrg               string   `flag:"github-org" cfg:"github_org"`
 | 
				
			||||||
 | 
						GitHubTeam              string   `flag:"github-team" cfg:"github_team"`
 | 
				
			||||||
	HtpasswdFile            string   `flag:"htpasswd-file" cfg:"htpasswd_file"`
 | 
						HtpasswdFile            string   `flag:"htpasswd-file" cfg:"htpasswd_file"`
 | 
				
			||||||
	DisplayHtpasswdForm     bool     `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"`
 | 
						DisplayHtpasswdForm     bool     `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"`
 | 
				
			||||||
	CustomTemplatesDir      string   `flag:"custom-templates-dir" cfg:"custom_templates_dir"`
 | 
						CustomTemplatesDir      string   `flag:"custom-templates-dir" cfg:"custom_templates_dir"`
 | 
				
			||||||
| 
						 | 
					@ -153,11 +155,16 @@ func (o *Options) Validate() error {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func parseProviderInfo(o *Options, msgs []string) []string {
 | 
					func parseProviderInfo(o *Options, msgs []string) []string {
 | 
				
			||||||
	p := &providers.ProviderData{Scope: o.Scope}
 | 
						p := &providers.ProviderData{Scope: o.Scope, ClientID: o.ClientID, ClientSecret: o.ClientSecret}
 | 
				
			||||||
	p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs)
 | 
						p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs)
 | 
				
			||||||
	p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs)
 | 
						p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs)
 | 
				
			||||||
	p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs)
 | 
						p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs)
 | 
				
			||||||
	p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs)
 | 
						p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	o.provider = providers.New(o.Provider, p)
 | 
						o.provider = providers.New(o.Provider, p)
 | 
				
			||||||
 | 
						switch p := o.provider.(type) {
 | 
				
			||||||
 | 
						case *providers.GitHubProvider:
 | 
				
			||||||
 | 
							p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return msgs
 | 
						return msgs
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,136 @@
 | 
				
			||||||
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GitHubProvider struct {
 | 
				
			||||||
 | 
						*ProviderData
 | 
				
			||||||
 | 
						Org  string
 | 
				
			||||||
 | 
						Team string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewGitHubProvider(p *ProviderData) *GitHubProvider {
 | 
				
			||||||
 | 
						p.ProviderName = "GitHub"
 | 
				
			||||||
 | 
						if p.LoginUrl.String() == "" {
 | 
				
			||||||
 | 
							p.LoginUrl = &url.URL{
 | 
				
			||||||
 | 
								Scheme: "https",
 | 
				
			||||||
 | 
								Host:   "github.com",
 | 
				
			||||||
 | 
								Path:   "/login/oauth/authorize",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if p.RedeemUrl.String() == "" {
 | 
				
			||||||
 | 
							p.RedeemUrl = &url.URL{
 | 
				
			||||||
 | 
								Scheme: "https",
 | 
				
			||||||
 | 
								Host:   "github.com",
 | 
				
			||||||
 | 
								Path:   "/login/oauth/access_token",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if p.ValidateUrl.String() == "" {
 | 
				
			||||||
 | 
							p.ValidateUrl = &url.URL{
 | 
				
			||||||
 | 
								Scheme: "https",
 | 
				
			||||||
 | 
								Host:   "api.github.com",
 | 
				
			||||||
 | 
								Path:   "/user/emails",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if p.Scope == "" {
 | 
				
			||||||
 | 
							p.Scope = "user:email"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &GitHubProvider{ProviderData: p}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (p *GitHubProvider) SetOrgTeam(org, team string) {
 | 
				
			||||||
 | 
						p.Org = org
 | 
				
			||||||
 | 
						p.Team = team
 | 
				
			||||||
 | 
						if org != "" || team != "" {
 | 
				
			||||||
 | 
							p.Scope += " read:org"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var teams []struct {
 | 
				
			||||||
 | 
							Name string `json:"name"`
 | 
				
			||||||
 | 
							Slug string `json:"slug"`
 | 
				
			||||||
 | 
							Org  struct {
 | 
				
			||||||
 | 
								Login string `json:"login"`
 | 
				
			||||||
 | 
							} `json:"organization"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						params := url.Values{
 | 
				
			||||||
 | 
							"access_token": {accessToken},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req, _ := http.NewRequest("GET", "https://api.github.com/user/teams?"+params.Encode(), nil)
 | 
				
			||||||
 | 
						req.Header.Set("Accept", "application/vnd.github.moondragon+json")
 | 
				
			||||||
 | 
						resp, err := http.DefaultClient.Do(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := json.Unmarshal(body, &teams); err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, team := range teams {
 | 
				
			||||||
 | 
							if p.Org == team.Org.Login {
 | 
				
			||||||
 | 
								if p.Team == "" || p.Team == team.Slug {
 | 
				
			||||||
 | 
									return true, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var emails []struct {
 | 
				
			||||||
 | 
							Email   string `json:"email"`
 | 
				
			||||||
 | 
							Primary bool   `json:"primary"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						params := url.Values{
 | 
				
			||||||
 | 
							"access_token": {access_token},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// if we require an Org or Team, check that first
 | 
				
			||||||
 | 
						if p.Org != "" || p.Team != "" {
 | 
				
			||||||
 | 
							if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp, err := http.DefaultClient.Get("https://api.github.com/user/emails?" + params.Encode())
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						body, err = ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := json.Unmarshal(body, &emails); err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, email := range emails {
 | 
				
			||||||
 | 
							if email.Primary {
 | 
				
			||||||
 | 
								return email.Email, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return "", nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *GitHubProvider) ValidateToken(access_token string) bool {
 | 
				
			||||||
 | 
						return validateToken(p, access_token, nil)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -2,10 +2,10 @@ package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type GoogleProvider struct {
 | 
					type GoogleProvider struct {
 | 
				
			||||||
| 
						 | 
					@ -35,28 +35,34 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
 | 
				
			||||||
	return &GoogleProvider{ProviderData: p}
 | 
						return &GoogleProvider{ProviderData: p}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *GoogleProvider) GetEmailAddress(auth_response *simplejson.Json,
 | 
					func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
				
			||||||
	unused_access_token string) (string, error) {
 | 
						var response struct {
 | 
				
			||||||
	idToken, err := auth_response.Get("id_token").String()
 | 
							IdToken string `json:"id_token"`
 | 
				
			||||||
	if err != nil {
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := json.Unmarshal(body, &response); err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// id_token is a base64 encode ID token payload
 | 
						// id_token is a base64 encode ID token payload
 | 
				
			||||||
	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | 
						// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | 
				
			||||||
	jwt := strings.Split(idToken, ".")
 | 
						jwt := strings.Split(response.IdToken, ".")
 | 
				
			||||||
	b, err := jwtDecodeSegment(jwt[1])
 | 
						b, err := jwtDecodeSegment(jwt[1])
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	data, err := simplejson.NewJson(b)
 | 
					
 | 
				
			||||||
 | 
						var email struct {
 | 
				
			||||||
 | 
							Email string `json:"email"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = json.Unmarshal(b, &email)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	email, err := data.Get("email").String()
 | 
						if email.Email == "" {
 | 
				
			||||||
	if err != nil {
 | 
							return "", errors.New("missing email")
 | 
				
			||||||
		return "", err
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return email, nil
 | 
						return email.Email, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func jwtDecodeSegment(seg string) ([]byte, error) {
 | 
					func jwtDecodeSegment(seg string) ([]byte, error) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,7 +2,7 @@ package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
						"encoding/json"
 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
| 
						 | 
					@ -68,39 +68,61 @@ func TestGoogleProviderOverrides(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGoogleProviderGetEmailAddress(t *testing.T) {
 | 
					func TestGoogleProviderGetEmailAddress(t *testing.T) {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						p := newGoogleProvider()
 | 
				
			||||||
	j := simplejson.New()
 | 
						body, err := json.Marshal(
 | 
				
			||||||
	j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString(
 | 
							struct {
 | 
				
			||||||
		[]byte("{\"email\": \"michael.bland@gsa.gov\"}")))
 | 
								IdToken string `json:"id_token"`
 | 
				
			||||||
	email, err := p.GetEmailAddress(j, "ignored access_token")
 | 
							}{
 | 
				
			||||||
 | 
								IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
						assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
 | 
					func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						p := newGoogleProvider()
 | 
				
			||||||
	j := simplejson.New()
 | 
						body, err := json.Marshal(
 | 
				
			||||||
	j.Set("id_token", "ignored prefix.{\"email\": \"michael.bland@gsa.gov\"}")
 | 
							struct {
 | 
				
			||||||
	email, err := p.GetEmailAddress(j, "ignored access_token")
 | 
								IdToken string `json:"id_token"`
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
 | 
					func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						p := newGoogleProvider()
 | 
				
			||||||
	j := simplejson.New()
 | 
					
 | 
				
			||||||
	j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString(
 | 
						body, err := json.Marshal(
 | 
				
			||||||
		[]byte("{email: michael.bland@gsa.gov}")))
 | 
							struct {
 | 
				
			||||||
	email, err := p.GetEmailAddress(j, "ignored access_token")
 | 
								IdToken string `json:"id_token"`
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
 | 
					func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						p := newGoogleProvider()
 | 
				
			||||||
	j := simplejson.New()
 | 
						body, err := json.Marshal(
 | 
				
			||||||
	j.Set("id_token", "ignored prefix."+base64.URLEncoding.EncodeToString(
 | 
							struct {
 | 
				
			||||||
		[]byte("{\"not_email\": \"missing!\"}")))
 | 
								IdToken string `json:"id_token"`
 | 
				
			||||||
	email, err := p.GetEmailAddress(j, "ignored access_token")
 | 
							}{
 | 
				
			||||||
 | 
								IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,6 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
| 
						 | 
					@ -13,9 +12,7 @@ type ValidateTokenTestProvider struct {
 | 
				
			||||||
	*ProviderData
 | 
						*ProviderData
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (tp *ValidateTokenTestProvider) GetEmailAddress(
 | 
					func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
				
			||||||
	unused_auth_response *simplejson.Json,
 | 
					 | 
				
			||||||
	unused_access_token string) (string, error) {
 | 
					 | 
				
			||||||
	return "", nil
 | 
						return "", nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,7 +8,6 @@ import (
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
	"github.com/bitly/google_auth_proxy/api"
 | 
						"github.com/bitly/google_auth_proxy/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,8 +49,7 @@ func getLinkedInHeader(access_token string) http.Header {
 | 
				
			||||||
	return header
 | 
						return header
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
 | 
					func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
				
			||||||
	access_token string) (string, error) {
 | 
					 | 
				
			||||||
	if access_token == "" {
 | 
						if access_token == "" {
 | 
				
			||||||
		return "", errors.New("missing access token")
 | 
							return "", errors.New("missing access token")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,6 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
| 
						 | 
					@ -97,9 +96,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testLinkedInProvider(b_url.Host)
 | 
						p := testLinkedInProvider(b_url.Host)
 | 
				
			||||||
	unused_auth_response := simplejson.New()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress(unused_auth_response,
 | 
						email, err := p.GetEmailAddress([]byte{},
 | 
				
			||||||
		"imaginary_access_token")
 | 
							"imaginary_access_token")
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "user@linkedin.com", email)
 | 
						assert.Equal(t, "user@linkedin.com", email)
 | 
				
			||||||
| 
						 | 
					@ -111,13 +109,11 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testLinkedInProvider(b_url.Host)
 | 
						p := testLinkedInProvider(b_url.Host)
 | 
				
			||||||
	unused_auth_response := simplejson.New()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// We'll trigger a request failure by using an unexpected access
 | 
						// We'll trigger a request failure by using an unexpected access
 | 
				
			||||||
	// token. Alternatively, we could allow the parsing of the payload as
 | 
						// token. Alternatively, we could allow the parsing of the payload as
 | 
				
			||||||
	// JSON to fail.
 | 
						// JSON to fail.
 | 
				
			||||||
	email, err := p.GetEmailAddress(unused_auth_response,
 | 
						email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token")
 | 
				
			||||||
		"unexpected_access_token")
 | 
					 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -128,10 +124,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testLinkedInProvider(b_url.Host)
 | 
						p := testLinkedInProvider(b_url.Host)
 | 
				
			||||||
	unused_auth_response := simplejson.New()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress(unused_auth_response,
 | 
						email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
 | 
				
			||||||
		"imaginary_access_token")
 | 
					 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,7 +5,6 @@ import (
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
	"github.com/bitly/google_auth_proxy/api"
 | 
						"github.com/bitly/google_auth_proxy/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -43,8 +42,7 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider {
 | 
				
			||||||
	return &MyUsaProvider{ProviderData: p}
 | 
						return &MyUsaProvider{ProviderData: p}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *MyUsaProvider) GetEmailAddress(auth_response *simplejson.Json,
 | 
					func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
				
			||||||
	access_token string) (string, error) {
 | 
					 | 
				
			||||||
	req, err := http.NewRequest("GET",
 | 
						req, err := http.NewRequest("GET",
 | 
				
			||||||
		p.ProfileUrl.String()+"?access_token="+access_token, nil)
 | 
							p.ProfileUrl.String()+"?access_token="+access_token, nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,6 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
| 
						 | 
					@ -102,10 +101,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testMyUsaProvider(b_url.Host)
 | 
						p := testMyUsaProvider(b_url.Host)
 | 
				
			||||||
	unused_auth_response := simplejson.New()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress(unused_auth_response,
 | 
						email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
 | 
				
			||||||
		"imaginary_access_token")
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
						assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -118,13 +115,11 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testMyUsaProvider(b_url.Host)
 | 
						p := testMyUsaProvider(b_url.Host)
 | 
				
			||||||
	unused_auth_response := simplejson.New()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// We'll trigger a request failure by using an unexpected access
 | 
						// We'll trigger a request failure by using an unexpected access
 | 
				
			||||||
	// token. Alternatively, we could allow the parsing of the payload as
 | 
						// token. Alternatively, we could allow the parsing of the payload as
 | 
				
			||||||
	// JSON to fail.
 | 
						// JSON to fail.
 | 
				
			||||||
	email, err := p.GetEmailAddress(unused_auth_response,
 | 
						email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token")
 | 
				
			||||||
		"unexpected_access_token")
 | 
					 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -135,10 +130,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testMyUsaProvider(b_url.Host)
 | 
						p := testMyUsaProvider(b_url.Host)
 | 
				
			||||||
	unused_auth_response := simplejson.New()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress(unused_auth_response,
 | 
						email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
 | 
				
			||||||
		"imaginary_access_token")
 | 
					 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,6 +6,8 @@ import (
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ProviderData struct {
 | 
					type ProviderData struct {
 | 
				
			||||||
	ProviderName string
 | 
						ProviderName string
 | 
				
			||||||
 | 
						ClientID     string
 | 
				
			||||||
 | 
						ClientSecret string
 | 
				
			||||||
	LoginUrl     *url.URL
 | 
						LoginUrl     *url.URL
 | 
				
			||||||
	RedeemUrl    *url.URL
 | 
						RedeemUrl    *url.URL
 | 
				
			||||||
	ProfileUrl   *url.URL
 | 
						ProfileUrl   *url.URL
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,51 @@
 | 
				
			||||||
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) {
 | 
				
			||||||
 | 
						if code == "" {
 | 
				
			||||||
 | 
							err = errors.New("missing code")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						params := url.Values{}
 | 
				
			||||||
 | 
						params.Add("redirect_uri", redirectUrl)
 | 
				
			||||||
 | 
						params.Add("client_id", p.ClientID)
 | 
				
			||||||
 | 
						params.Add("client_secret", p.ClientSecret)
 | 
				
			||||||
 | 
						params.Add("code", code)
 | 
				
			||||||
 | 
						params.Add("grant_type", "authorization_code")
 | 
				
			||||||
 | 
						req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode()))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp, err := http.DefaultClient.Do(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						body, err = ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// blindly try json and x-www-form-urlencoded
 | 
				
			||||||
 | 
						var jsonResponse struct {
 | 
				
			||||||
 | 
							AccessToken string `json:"access_token"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = json.Unmarshal(body, &jsonResponse)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							return body, jsonResponse.AccessToken, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						v, err := url.ParseQuery(string(body))
 | 
				
			||||||
 | 
						return body, v.Get("access_token"), err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -1,13 +1,9 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type Provider interface {
 | 
					type Provider interface {
 | 
				
			||||||
	Data() *ProviderData
 | 
						Data() *ProviderData
 | 
				
			||||||
	GetEmailAddress(auth_response *simplejson.Json,
 | 
						GetEmailAddress(body []byte, access_token string) (string, error)
 | 
				
			||||||
		access_token string) (string, error)
 | 
						Redeem(string, string) ([]byte, string, error)
 | 
				
			||||||
	ValidateToken(access_token string) bool
 | 
						ValidateToken(access_token string) bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,6 +13,8 @@ func New(provider string, p *ProviderData) Provider {
 | 
				
			||||||
		return NewMyUsaProvider(p)
 | 
							return NewMyUsaProvider(p)
 | 
				
			||||||
	case "linkedin":
 | 
						case "linkedin":
 | 
				
			||||||
		return NewLinkedInProvider(p)
 | 
							return NewLinkedInProvider(p)
 | 
				
			||||||
 | 
						case "github":
 | 
				
			||||||
 | 
							return NewGitHubProvider(p)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return NewGoogleProvider(p)
 | 
							return NewGoogleProvider(p)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue