Merge pull request #81 from 18F/access-token-refactor
Refactor pass_access_token changes from #80
This commit is contained in:
		
						commit
						9534808a0d
					
				|  | @ -24,3 +24,6 @@ _testmain.go | ||||||
| *.exe | *.exe | ||||||
| dist | dist | ||||||
| .godeps | .godeps | ||||||
|  | 
 | ||||||
|  | # Editor swap/temp files | ||||||
|  | .*.swp | ||||||
|  |  | ||||||
							
								
								
									
										31
									
								
								cookies.go
								
								
								
								
							
							
						
						
									
										31
									
								
								cookies.go
								
								
								
								
							|  | @ -97,3 +97,34 @@ func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (st | ||||||
| 
 | 
 | ||||||
| 	return string(encrypted_access_token), nil | 	return string(encrypted_access_token), nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func buildCookieValue(email string, aes_cipher cipher.Block, | ||||||
|  | 	access_token string) (string, error) { | ||||||
|  | 	if aes_cipher == nil { | ||||||
|  | 		return email, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	encoded_token, err := encodeAccessToken(aes_cipher, access_token) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return email, fmt.Errorf( | ||||||
|  | 			"error encoding access token for %s: %s", email, err) | ||||||
|  | 	} | ||||||
|  | 	return email + "|" + encoded_token, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func parseCookieValue(value string, aes_cipher cipher.Block) (email, user, | ||||||
|  | 	access_token string, err error) { | ||||||
|  | 	components := strings.Split(value, "|") | ||||||
|  | 	email = components[0] | ||||||
|  | 	user = strings.Split(email, "@")[0] | ||||||
|  | 
 | ||||||
|  | 	if aes_cipher != nil && len(components) == 2 { | ||||||
|  | 		access_token, err = decodeAccessToken(aes_cipher, components[1]) | ||||||
|  | 		if err != nil { | ||||||
|  | 			err = fmt.Errorf( | ||||||
|  | 				"error decoding access token for %s: %s", | ||||||
|  | 				email, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return email, user, access_token, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -3,6 +3,7 @@ package main | ||||||
| import ( | import ( | ||||||
| 	"crypto/aes" | 	"crypto/aes" | ||||||
| 	"github.com/bmizerany/assert" | 	"github.com/bmizerany/assert" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -21,3 +22,54 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||||
| 	assert.NotEqual(t, access_token, encoded_token) | 	assert.NotEqual(t, access_token, encoded_token) | ||||||
| 	assert.Equal(t, access_token, decoded_token) | 	assert.Equal(t, access_token, decoded_token) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestBuildCookieValueWithoutAccessToken(t *testing.T) { | ||||||
|  | 	value, err := buildCookieValue("michael.bland@gsa.gov", nil, "") | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, "michael.bland@gsa.gov", value) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) { | ||||||
|  | 	value, err := buildCookieValue("michael.bland@gsa.gov", nil, | ||||||
|  | 		"access token") | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, "michael.bland@gsa.gov", value) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestParseCookieValueWithoutAccessToken(t *testing.T) { | ||||||
|  | 	email, user, access_token, err := parseCookieValue( | ||||||
|  | 		"michael.bland@gsa.gov", nil) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
|  | 	assert.Equal(t, "michael.bland", user) | ||||||
|  | 	assert.Equal(t, "", access_token) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) { | ||||||
|  | 	email, user, access_token, err := parseCookieValue( | ||||||
|  | 		"michael.bland@gsa.gov|access_token", nil) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
|  | 	assert.Equal(t, "michael.bland", user) | ||||||
|  | 	assert.Equal(t, "", access_token) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) { | ||||||
|  | 	aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef")) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	value, err := buildCookieValue("michael.bland@gsa.gov", aes_cipher, | ||||||
|  | 		"access_token") | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 	prefix := "michael.bland@gsa.gov|" | ||||||
|  | 	if !strings.HasPrefix(value, prefix) { | ||||||
|  | 		t.Fatal("cookie value does not start with \"%s\": %s", | ||||||
|  | 			prefix, value) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	email, user, access_token, err := parseCookieValue(value, aes_cipher) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
|  | 	assert.Equal(t, "michael.bland", user) | ||||||
|  | 	assert.Equal(t, "access_token", access_token) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -47,7 +47,6 @@ type OauthProxy struct { | ||||||
| 	DisplayHtpasswdForm bool | 	DisplayHtpasswdForm bool | ||||||
| 	serveMux            http.Handler | 	serveMux            http.Handler | ||||||
| 	PassBasicAuth       bool | 	PassBasicAuth       bool | ||||||
| 	PassAccessToken     bool |  | ||||||
| 	AesCipher           cipher.Block | 	AesCipher           cipher.Block | ||||||
| 	skipAuthRegex       []string | 	skipAuthRegex       []string | ||||||
| 	compiledRegex       []*regexp.Regexp | 	compiledRegex       []*regexp.Regexp | ||||||
|  | @ -121,20 +120,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | ||||||
| 	log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) | 	log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) | ||||||
| 
 | 
 | ||||||
| 	var aes_cipher cipher.Block | 	var aes_cipher cipher.Block | ||||||
| 
 | 	if opts.PassAccessToken { | ||||||
| 	if opts.PassAccessToken == true { |  | ||||||
| 		valid_cookie_secret_size := false |  | ||||||
| 		for _, i := range []int{16, 24, 32} { |  | ||||||
| 			if len(opts.CookieSecret) == i { |  | ||||||
| 				valid_cookie_secret_size = true |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		if valid_cookie_secret_size == false { |  | ||||||
| 			log.Fatal("cookie_secret must be 16, 24, or 32 bytes " + |  | ||||||
| 				"to create an AES cipher when " + |  | ||||||
| 				"pass_access_token == true") |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		var err error | 		var err error | ||||||
| 		aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) | 		aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | @ -163,7 +149,6 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | ||||||
| 		skipAuthRegex:      opts.SkipAuthRegex, | 		skipAuthRegex:      opts.SkipAuthRegex, | ||||||
| 		compiledRegex:      opts.CompiledRegex, | 		compiledRegex:      opts.CompiledRegex, | ||||||
| 		PassBasicAuth:      opts.PassBasicAuth, | 		PassBasicAuth:      opts.PassBasicAuth, | ||||||
| 		PassAccessToken:    opts.PassAccessToken, |  | ||||||
| 		AesCipher:          aes_cipher, | 		AesCipher:          aes_cipher, | ||||||
| 		templates:          loadTemplates(opts.CustomTemplatesDir), | 		templates:          loadTemplates(opts.CustomTemplatesDir), | ||||||
| 	} | 	} | ||||||
|  | @ -440,20 +425,12 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		// set cookie, or deny
 | 		// set cookie, or deny
 | ||||||
| 		if p.Validator(email) { | 		if p.Validator(email) { | ||||||
| 			log.Printf("%s authenticating %s completed", remoteAddr, email) | 			log.Printf("%s authenticating %s completed", remoteAddr, email) | ||||||
| 			encoded_token := "" | 			value, err := buildCookieValue( | ||||||
| 			if p.PassAccessToken { | 				email, p.AesCipher, access_token) | ||||||
| 				encoded_token, err = encodeAccessToken(p.AesCipher, access_token) |  | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 					log.Printf("error encoding access token: %s", err) | 				log.Printf(err.Error()) | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			access_token = "" |  | ||||||
| 
 |  | ||||||
| 			if encoded_token != "" { |  | ||||||
| 				p.SetCookie(rw, req, email+"|"+encoded_token) |  | ||||||
| 			} else { |  | ||||||
| 				p.SetCookie(rw, req, email) |  | ||||||
| 			} | 			} | ||||||
|  | 			p.SetCookie(rw, req, value) | ||||||
| 			http.Redirect(rw, req, redirect, 302) | 			http.Redirect(rw, req, redirect, 302) | ||||||
| 			return | 			return | ||||||
| 		} else { | 		} else { | ||||||
|  | @ -467,15 +444,13 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			var value string | 			var value string | ||||||
| 			value, ok = validateCookie(cookie, p.CookieSeed) | 			value, ok = validateCookie(cookie, p.CookieSeed) | ||||||
| 			components := strings.Split(value, "|") | 			if ok { | ||||||
| 			email = components[0] | 				email, user, access_token, err = parseCookieValue( | ||||||
| 			if len(components) == 2 { | 					value, p.AesCipher) | ||||||
| 				access_token, err = decodeAccessToken(p.AesCipher, components[1]) |  | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					log.Printf("error decoding access token: %s", err) | 					log.Printf(err.Error()) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			user = strings.Split(email, "@")[0] |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -152,24 +152,25 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes | ||||||
| 	return t | 	return t | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Close(t *PassAccessTokenTest) { | func (pat_test *PassAccessTokenTest) Close() { | ||||||
| 	t.provider_server.Close() | 	pat_test.provider_server.Close() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getCallbackEndpoint(pac_test *PassAccessTokenTest) (http_code int, cookie string) { | func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, | ||||||
|  | 	cookie string) { | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
| 	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code", | 	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code", | ||||||
| 		strings.NewReader("")) | 		strings.NewReader("")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, "" | 		return 0, "" | ||||||
| 	} | 	} | ||||||
| 	pac_test.proxy.ServeHTTP(rw, req) | 	pat_test.proxy.ServeHTTP(rw, req) | ||||||
| 	return rw.Code, rw.HeaderMap["Set-Cookie"][0] | 	return rw.Code, rw.HeaderMap["Set-Cookie"][0] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code int, | func (pat_test *PassAccessTokenTest) getRootEndpoint( | ||||||
| 	access_token string) { | 	cookie string) (http_code int, access_token string) { | ||||||
| 	cookie_key := pac_test.proxy.CookieKey | 	cookie_key := pat_test.proxy.CookieKey | ||||||
| 	var value string | 	var value string | ||||||
| 	key_prefix := cookie_key + "=" | 	key_prefix := cookie_key + "=" | ||||||
| 
 | 
 | ||||||
|  | @ -198,43 +199,43 @@ func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code in | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
| 	pac_test.proxy.ServeHTTP(rw, req) | 	pat_test.proxy.ServeHTTP(rw, req) | ||||||
| 	return rw.Code, rw.Body.String() | 	return rw.Code, rw.Body.String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestForwardAccessTokenUpstream(t *testing.T) { | func TestForwardAccessTokenUpstream(t *testing.T) { | ||||||
| 	pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | 	pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | ||||||
| 		PassAccessToken: true, | 		PassAccessToken: true, | ||||||
| 	}) | 	}) | ||||||
| 	defer Close(pac_test) | 	defer pat_test.Close() | ||||||
| 
 | 
 | ||||||
| 	// A successful validation will redirect and set the auth cookie.
 | 	// A successful validation will redirect and set the auth cookie.
 | ||||||
| 	code, cookie := getCallbackEndpoint(pac_test) | 	code, cookie := pat_test.getCallbackEndpoint() | ||||||
| 	assert.Equal(t, 302, code) | 	assert.Equal(t, 302, code) | ||||||
| 	assert.NotEqual(t, nil, cookie) | 	assert.NotEqual(t, nil, cookie) | ||||||
| 
 | 
 | ||||||
| 	// Now we make a regular request; the access_token from the cookie is
 | 	// Now we make a regular request; the access_token from the cookie is
 | ||||||
| 	// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | 	// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | ||||||
| 	// read by the test provider server and written in the response body.
 | 	// read by the test provider server and written in the response body.
 | ||||||
| 	code, payload := getRootEndpoint(pac_test, cookie) | 	code, payload := pat_test.getRootEndpoint(cookie) | ||||||
| 	assert.Equal(t, 200, code) | 	assert.Equal(t, 200, code) | ||||||
| 	assert.Equal(t, "my_auth_token", payload) | 	assert.Equal(t, "my_auth_token", payload) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | ||||||
| 	pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | 	pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | ||||||
| 		PassAccessToken: false, | 		PassAccessToken: false, | ||||||
| 	}) | 	}) | ||||||
| 	defer Close(pac_test) | 	defer pat_test.Close() | ||||||
| 
 | 
 | ||||||
| 	// A successful validation will redirect and set the auth cookie.
 | 	// A successful validation will redirect and set the auth cookie.
 | ||||||
| 	code, cookie := getCallbackEndpoint(pac_test) | 	code, cookie := pat_test.getCallbackEndpoint() | ||||||
| 	assert.Equal(t, 302, code) | 	assert.Equal(t, 302, code) | ||||||
| 	assert.NotEqual(t, nil, cookie) | 	assert.NotEqual(t, nil, cookie) | ||||||
| 
 | 
 | ||||||
| 	// Now we make a regular request, but the access token header should
 | 	// Now we make a regular request, but the access token header should
 | ||||||
| 	// not be present.
 | 	// not be present.
 | ||||||
| 	code, payload := getRootEndpoint(pac_test, cookie) | 	code, payload := pat_test.getRootEndpoint(cookie) | ||||||
| 	assert.Equal(t, 200, code) | 	assert.Equal(t, 200, code) | ||||||
| 	assert.Equal(t, "No access token found.", payload) | 	assert.Equal(t, "No access token found.", payload) | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										17
									
								
								options.go
								
								
								
								
							
							
						
						
									
										17
									
								
								options.go
								
								
								
								
							|  | @ -117,6 +117,23 @@ func (o *Options) Validate() error { | ||||||
| 	} | 	} | ||||||
| 	msgs = parseProviderInfo(o, msgs) | 	msgs = parseProviderInfo(o, msgs) | ||||||
| 
 | 
 | ||||||
|  | 	if o.PassAccessToken { | ||||||
|  | 		valid_cookie_secret_size := false | ||||||
|  | 		for _, i := range []int{16, 24, 32} { | ||||||
|  | 			if len(o.CookieSecret) == i { | ||||||
|  | 				valid_cookie_secret_size = true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if valid_cookie_secret_size == false { | ||||||
|  | 			msgs = append(msgs, fmt.Sprintf( | ||||||
|  | 				"cookie_secret must be 16, 24, or 32 bytes "+ | ||||||
|  | 					"to create an AES cipher when "+ | ||||||
|  | 					"pass_access_token == true, "+ | ||||||
|  | 					"but is %d bytes", | ||||||
|  | 				len(o.CookieSecret))) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if len(msgs) != 0 { | 	if len(msgs) != 0 { | ||||||
| 		return fmt.Errorf("Invalid configuration:\n  %s", | 		return fmt.Errorf("Invalid configuration:\n  %s", | ||||||
| 			strings.Join(msgs, "\n  ")) | 			strings.Join(msgs, "\n  ")) | ||||||
|  |  | ||||||
|  | @ -102,3 +102,22 @@ func TestDefaultProviderApiSettings(t *testing.T) { | ||||||
| 	assert.Equal(t, "", p.ProfileUrl.String()) | 	assert.Equal(t, "", p.ProfileUrl.String()) | ||||||
| 	assert.Equal(t, "profile email", p.Scope) | 	assert.Equal(t, "profile email", p.Scope) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) { | ||||||
|  | 	o := testOptions() | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	assert.Equal(t, false, o.PassAccessToken) | ||||||
|  | 	o.PassAccessToken = true | ||||||
|  | 	o.CookieSecret = "cookie of invalid length-" | ||||||
|  | 	assert.NotEqual(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	o.CookieSecret = "16 bytes AES-128" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	o.CookieSecret = "24 byte secret AES-192--" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	o.CookieSecret = "32 byte secret for AES-256------" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue