base64 cookie support
This commit is contained in:
		
							parent
							
								
									57f82ed71e
								
							
						
					
					
						commit
						cdebfd6436
					
				|  | @ -85,8 +85,8 @@ type Cipher struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewCipher returns a new aes Cipher for encrypting cookie values
 | // NewCipher returns a new aes Cipher for encrypting cookie values
 | ||||||
| func NewCipher(secret string) (*Cipher, error) { | func NewCipher(secret []byte) (*Cipher, error) { | ||||||
| 	c, err := aes.NewCipher([]byte(secret)) | 	c, err := aes.NewCipher(secret) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| package cookie | package cookie | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/base64" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bmizerany/assert" | 	"github.com/bmizerany/assert" | ||||||
|  | @ -9,7 +10,25 @@ import ( | ||||||
| func TestEncodeAndDecodeAccessToken(t *testing.T) { | func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||||
| 	const secret = "0123456789abcdefghijklmnopqrstuv" | 	const secret = "0123456789abcdefghijklmnopqrstuv" | ||||||
| 	const token = "my access token" | 	const token = "my access token" | ||||||
| 	c, err := NewCipher(secret) | 	c, err := NewCipher([]byte(secret)) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 	encoded, err := c.Encrypt(token) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 	decoded, err := c.Decrypt(encoded) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 	assert.NotEqual(t, token, encoded) | ||||||
|  | 	assert.Equal(t, token, decoded) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { | ||||||
|  | 	const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" | ||||||
|  | 	const token = "my access token" | ||||||
|  | 
 | ||||||
|  | 	secret, err := base64.URLEncoding.DecodeString(secret_b64) | ||||||
|  | 	c, err := NewCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	encoded, err := c.Encrypt(token) | 	encoded, err := c.Encrypt(token) | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								main.go
								
								
								
								
							
							
						
						
									
										2
									
								
								main.go
								
								
								
								
							|  | @ -55,7 +55,7 @@ func main() { | ||||||
| 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | ||||||
| 
 | 
 | ||||||
| 	flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates") | 	flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates") | ||||||
| 	flagSet.String("cookie-secret", "", "the seed string for secure cookies") | 	flagSet.String("cookie-secret", "", "the seed string for secure cookies (optionally base64 encoded)") | ||||||
| 	flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") | 	flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") | ||||||
| 	flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") | 	flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") | ||||||
| 	flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie after this duration; 0 to disable") | 	flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie after this duration; 0 to disable") | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/base64" | 	b64 "encoding/base64" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"html/template" | 	"html/template" | ||||||
|  | @ -164,10 +164,9 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 	var cipher *cookie.Cipher | 	var cipher *cookie.Cipher | ||||||
| 	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { | 	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { | ||||||
| 		var err error | 		var err error | ||||||
| 		cipher, err = cookie.NewCipher(opts.CookieSecret) | 		cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Fatal("error creating AES cipher with "+ | 			log.Fatal("cookie-secret error: ", err) | ||||||
| 				"cookie-secret ", opts.CookieSecret, ": ", err) |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -626,7 +625,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, | ||||||
| 	if len(s) != 2 || s[0] != "Basic" { | 	if len(s) != 2 || s[0] != "Basic" { | ||||||
| 		return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) | 		return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) | ||||||
| 	} | 	} | ||||||
| 	b, err := base64.StdEncoding.DecodeString(s[1]) | 	b, err := b64.StdEncoding.DecodeString(s[1]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -427,7 +427,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { | ||||||
| 	pc_test.opts = NewOptions() | 	pc_test.opts = NewOptions() | ||||||
| 	pc_test.opts.ClientID = "bazquux" | 	pc_test.opts.ClientID = "bazquux" | ||||||
| 	pc_test.opts.ClientSecret = "xyzzyplugh" | 	pc_test.opts.ClientSecret = "xyzzyplugh" | ||||||
| 	pc_test.opts.CookieSecret = "0123456789abcdef" | 	pc_test.opts.CookieSecret = "0123456789abcdefabcd" | ||||||
| 	// First, set the CookieRefresh option so proxy.AesCipher is created,
 | 	// First, set the CookieRefresh option so proxy.AesCipher is created,
 | ||||||
| 	// needed to encrypt the access_token.
 | 	// needed to encrypt the access_token.
 | ||||||
| 	pc_test.opts.CookieRefresh = time.Hour | 	pc_test.opts.CookieRefresh = time.Hour | ||||||
|  |  | ||||||
							
								
								
									
										38
									
								
								options.go
								
								
								
								
							
							
						
						
									
										38
									
								
								options.go
								
								
								
								
							|  | @ -2,6 +2,7 @@ package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto" | 	"crypto" | ||||||
|  | 	"encoding/base64" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
|  | @ -156,17 +157,25 @@ func (o *Options) Validate() error { | ||||||
| 	if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { | 	if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { | ||||||
| 		valid_cookie_secret_size := false | 		valid_cookie_secret_size := false | ||||||
| 		for _, i := range []int{16, 24, 32} { | 		for _, i := range []int{16, 24, 32} { | ||||||
| 			if len(o.CookieSecret) == i { | 			if len(secretBytes(o.CookieSecret)) == i { | ||||||
| 				valid_cookie_secret_size = true | 				valid_cookie_secret_size = true | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | 		var decoded bool | ||||||
|  | 		if string(secretBytes(o.CookieSecret)) != o.CookieSecret { | ||||||
|  | 			decoded = true | ||||||
|  | 		} | ||||||
| 		if valid_cookie_secret_size == false { | 		if valid_cookie_secret_size == false { | ||||||
|  | 			var suffix string | ||||||
|  | 			if decoded { | ||||||
|  | 				suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) | ||||||
|  | 			} | ||||||
| 			msgs = append(msgs, fmt.Sprintf( | 			msgs = append(msgs, fmt.Sprintf( | ||||||
| 				"cookie_secret must be 16, 24, or 32 bytes "+ | 				"cookie_secret must be 16, 24, or 32 bytes "+ | ||||||
| 					"to create an AES cipher when "+ | 					"to create an AES cipher when "+ | ||||||
| 					"pass_access_token == true or "+ | 					"pass_access_token == true or "+ | ||||||
| 					"cookie_refresh != 0, but is %d bytes", | 					"cookie_refresh != 0, but is %d bytes.%s", | ||||||
| 				len(o.CookieSecret))) | 				len(secretBytes(o.CookieSecret)), suffix)) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -251,3 +260,26 @@ func parseSignatureKey(o *Options, msgs []string) []string { | ||||||
| 	} | 	} | ||||||
| 	return msgs | 	return msgs | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func addPadding(secret string) string { | ||||||
|  | 	padding := len(secret) % 4 | ||||||
|  | 	switch padding { | ||||||
|  | 	case 1: | ||||||
|  | 		return secret + "===" | ||||||
|  | 	case 2: | ||||||
|  | 		return secret + "==" | ||||||
|  | 	case 3: | ||||||
|  | 		return secret + "=" | ||||||
|  | 	default: | ||||||
|  | 		return secret | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // secretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
 | ||||||
|  | func secretBytes(secret string) []byte { | ||||||
|  | 	b, err := base64.URLEncoding.DecodeString(addPadding(secret)) | ||||||
|  | 	if err == nil { | ||||||
|  | 		return []byte(addPadding(string(b))) | ||||||
|  | 	} | ||||||
|  | 	return []byte(secret) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -160,7 +160,7 @@ func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { | ||||||
| 	o := testOptions() | 	o := testOptions() | ||||||
| 	assert.Equal(t, nil, o.Validate()) | 	assert.Equal(t, nil, o.Validate()) | ||||||
| 
 | 
 | ||||||
| 	o.CookieSecret = "0123456789abcdef" | 	o.CookieSecret = "0123456789abcdefabcd" | ||||||
| 	o.CookieRefresh = o.CookieExpire | 	o.CookieRefresh = o.CookieExpire | ||||||
| 	assert.NotEqual(t, nil, o.Validate()) | 	assert.NotEqual(t, nil, o.Validate()) | ||||||
| 
 | 
 | ||||||
|  | @ -168,6 +168,31 @@ func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, o.Validate()) | 	assert.Equal(t, nil, o.Validate()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestBase64CookieSecret(t *testing.T) { | ||||||
|  | 	o := testOptions() | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	// 32 byte, base64 (urlsafe) encoded key
 | ||||||
|  | 	o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ=" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	// 32 byte, base64 (urlsafe) encoded key, w/o padding
 | ||||||
|  | 	o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	// 24 byte, base64 (urlsafe) encoded key
 | ||||||
|  | 	o.CookieSecret = "Kp33Gj-GQmYtz4zZUyUDdqQKx5_Hgkv3" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	// 16 byte, base64 (urlsafe) encoded key
 | ||||||
|  | 	o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA==" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | 
 | ||||||
|  | 	// 16 byte, base64 (urlsafe) encoded key, w/o padding
 | ||||||
|  | 	o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA" | ||||||
|  | 	assert.Equal(t, nil, o.Validate()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestValidateSignatureKey(t *testing.T) { | func TestValidateSignatureKey(t *testing.T) { | ||||||
| 	o := testOptions() | 	o := testOptions() | ||||||
| 	o.SignatureKey = "sha1:secret" | 	o.SignatureKey = "sha1:secret" | ||||||
|  |  | ||||||
|  | @ -3,11 +3,11 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" |  | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"log" | 	"log" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type GitHubProvider struct { | type GitHubProvider struct { | ||||||
|  | @ -64,7 +64,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { | ||||||
| 		"limit":        {"100"}, | 		"limit":        {"100"}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	endpoint := p.ValidateURL.Scheme + "://"  + p.ValidateURL.Host + "/user/orgs?" + params.Encode() | 	endpoint := p.ValidateURL.Scheme + "://" + p.ValidateURL.Host + "/user/orgs?" + params.Encode() | ||||||
| 	req, _ := http.NewRequest("GET", endpoint, nil) | 	req, _ := http.NewRequest("GET", endpoint, nil) | ||||||
| 	req.Header.Set("Accept", "application/vnd.github.v3+json") | 	req.Header.Set("Accept", "application/vnd.github.v3+json") | ||||||
| 	resp, err := http.DefaultClient.Do(req) | 	resp, err := http.DefaultClient.Do(req) | ||||||
|  |  | ||||||
|  | @ -13,9 +13,9 @@ const secret = "0123456789abcdefghijklmnopqrstuv" | ||||||
| const altSecret = "0000000000abcdefghijklmnopqrstuv" | const altSecret = "0000000000abcdefghijklmnopqrstuv" | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerialization(t *testing.T) { | func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	c, err := cookie.NewCipher(secret) | 	c, err := cookie.NewCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := cookie.NewCipher(altSecret) | 	c2, err := cookie.NewCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &SessionState{ | 	s := &SessionState{ | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue