Merge pull request #117 from jehiah/always_refresh_117
Google - continually use refresh token
This commit is contained in:
		
						commit
						01c9d04feb
					
				|  | @ -3,6 +3,7 @@ package api | |||
| import ( | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/bitly/go-simplejson" | ||||
|  | @ -11,10 +12,12 @@ import ( | |||
| func Request(req *http.Request) (*simplejson.Json, error) { | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("%s %s %s", req.Method, req.URL, err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  |  | |||
|  | @ -0,0 +1,128 @@ | |||
| package cookie | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
 | ||||
| // additionally, the 'value' is encrypted so it's opaque to the browser
 | ||||
| 
 | ||||
| // Validate ensures a cookie is properly signed
 | ||||
| func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { | ||||
| 	// value, timestamp, sig
 | ||||
| 	parts := strings.Split(cookie.Value, "|") | ||||
| 	if len(parts) != 3 { | ||||
| 		return | ||||
| 	} | ||||
| 	sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) | ||||
| 	if checkHmac(parts[2], sig) { | ||||
| 		ts, err := strconv.Atoi(parts[1]) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		// The expiration timestamp set when the cookie was created
 | ||||
| 		// isn't sent back by the browser. Hence, we check whether the
 | ||||
| 		// creation timestamp stored in the cookie falls within the
 | ||||
| 		// window defined by (Now()-expiration, Now()].
 | ||||
| 		t = time.Unix(int64(ts), 0) | ||||
| 		if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { | ||||
| 			// it's a valid cookie. now get the contents
 | ||||
| 			rawValue, err := base64.URLEncoding.DecodeString(parts[0]) | ||||
| 			if err == nil { | ||||
| 				value = string(rawValue) | ||||
| 				ok = true | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // SignedValue returns a cookie that is signed and can later be checked with Validate
 | ||||
| func SignedValue(seed string, key string, value string, now time.Time) string { | ||||
| 	encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) | ||||
| 	timeStr := fmt.Sprintf("%d", now.Unix()) | ||||
| 	sig := cookieSignature(seed, key, encodedValue, timeStr) | ||||
| 	cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) | ||||
| 	return cookieVal | ||||
| } | ||||
| 
 | ||||
| func cookieSignature(args ...string) string { | ||||
| 	h := hmac.New(sha1.New, []byte(args[0])) | ||||
| 	for _, arg := range args[1:] { | ||||
| 		h.Write([]byte(arg)) | ||||
| 	} | ||||
| 	var b []byte | ||||
| 	b = h.Sum(b) | ||||
| 	return base64.URLEncoding.EncodeToString(b) | ||||
| } | ||||
| 
 | ||||
| func checkHmac(input, expected string) bool { | ||||
| 	inputMAC, err1 := base64.URLEncoding.DecodeString(input) | ||||
| 	if err1 == nil { | ||||
| 		expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) | ||||
| 		if err2 == nil { | ||||
| 			return hmac.Equal(inputMAC, expectedMAC) | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // Cipher provides methods to encrypt and decrypt cookie values
 | ||||
| type Cipher struct { | ||||
| 	cipher.Block | ||||
| } | ||||
| 
 | ||||
| // NewCipher returns a new aes Cipher for encrypting cookie values
 | ||||
| func NewCipher(secret string) (*Cipher, error) { | ||||
| 	c, err := aes.NewCipher([]byte(secret)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Cipher{Block: c}, err | ||||
| } | ||||
| 
 | ||||
| // Encrypt a value for use in a cookie
 | ||||
| func (c *Cipher) Encrypt(value string) (string, error) { | ||||
| 	ciphertext := make([]byte, aes.BlockSize+len(value)) | ||||
| 	iv := ciphertext[:aes.BlockSize] | ||||
| 	if _, err := io.ReadFull(rand.Reader, iv); err != nil { | ||||
| 		return "", fmt.Errorf("failed to create initialization vector %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	stream := cipher.NewCFBEncrypter(c.Block, iv) | ||||
| 	stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value)) | ||||
| 	return base64.StdEncoding.EncodeToString(ciphertext), nil | ||||
| } | ||||
| 
 | ||||
| // Decrypt a value from a cookie to it's original string
 | ||||
| func (c *Cipher) Decrypt(s string) (string, error) { | ||||
| 	encrypted, err := base64.StdEncoding.DecodeString(s) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("failed to decrypt cookie value %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(encrypted) < aes.BlockSize { | ||||
| 		return "", fmt.Errorf("encrypted cookie value should be "+ | ||||
| 			"at least %d bytes, but is only %d bytes", | ||||
| 			aes.BlockSize, len(encrypted)) | ||||
| 	} | ||||
| 
 | ||||
| 	iv := encrypted[:aes.BlockSize] | ||||
| 	encrypted = encrypted[aes.BlockSize:] | ||||
| 	stream := cipher.NewCFBDecrypter(c.Block, iv) | ||||
| 	stream.XORKeyStream(encrypted, encrypted) | ||||
| 
 | ||||
| 	return string(encrypted), nil | ||||
| } | ||||
|  | @ -0,0 +1,23 @@ | |||
| package cookie | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/bmizerany/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||
| 	const secret = "0123456789abcdefghijklmnopqrstuv" | ||||
| 	const token = "my access token" | ||||
| 	c, err := NewCipher(secret) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	encoded, err := c.Encrypt(token) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	decoded, err := c.Decrypt(encoded) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	assert.NotEqual(t, token, encoded) | ||||
| 	assert.Equal(t, token, decoded) | ||||
| } | ||||
							
								
								
									
										140
									
								
								cookies.go
								
								
								
								
							
							
						
						
									
										140
									
								
								cookies.go
								
								
								
								
							|  | @ -1,140 +0,0 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/hmac" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func validateCookie(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { | ||||
| 	// value, timestamp, sig
 | ||||
| 	parts := strings.Split(cookie.Value, "|") | ||||
| 	if len(parts) != 3 { | ||||
| 		return | ||||
| 	} | ||||
| 	sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) | ||||
| 	if checkHmac(parts[2], sig) { | ||||
| 		ts, err := strconv.Atoi(parts[1]) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		// The expiration timestamp set when the cookie was created
 | ||||
| 		// isn't sent back by the browser. Hence, we check whether the
 | ||||
| 		// creation timestamp stored in the cookie falls within the
 | ||||
| 		// window defined by (Now()-expiration, Now()].
 | ||||
| 		t = time.Unix(int64(ts), 0) | ||||
| 		if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { | ||||
| 			// it's a valid cookie. now get the contents
 | ||||
| 			rawValue, err := base64.URLEncoding.DecodeString(parts[0]) | ||||
| 			if err == nil { | ||||
| 				value = string(rawValue) | ||||
| 				ok = true | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func signedCookieValue(seed string, key string, value string, now time.Time) string { | ||||
| 	encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) | ||||
| 	timeStr := fmt.Sprintf("%d", now.Unix()) | ||||
| 	sig := cookieSignature(seed, key, encodedValue, timeStr) | ||||
| 	cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) | ||||
| 	return cookieVal | ||||
| } | ||||
| 
 | ||||
| func cookieSignature(args ...string) string { | ||||
| 	h := hmac.New(sha1.New, []byte(args[0])) | ||||
| 	for _, arg := range args[1:] { | ||||
| 		h.Write([]byte(arg)) | ||||
| 	} | ||||
| 	var b []byte | ||||
| 	b = h.Sum(b) | ||||
| 	return base64.URLEncoding.EncodeToString(b) | ||||
| } | ||||
| 
 | ||||
| func checkHmac(input, expected string) bool { | ||||
| 	inputMAC, err1 := base64.URLEncoding.DecodeString(input) | ||||
| 	if err1 == nil { | ||||
| 		expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) | ||||
| 		if err2 == nil { | ||||
| 			return hmac.Equal(inputMAC, expectedMAC) | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func encodeAccessToken(aes_cipher cipher.Block, access_token string) (string, error) { | ||||
| 	ciphertext := make([]byte, aes.BlockSize+len(access_token)) | ||||
| 	iv := ciphertext[:aes.BlockSize] | ||||
| 	if _, err := io.ReadFull(rand.Reader, iv); err != nil { | ||||
| 		return "", fmt.Errorf("failed to create access code initialization vector") | ||||
| 	} | ||||
| 
 | ||||
| 	stream := cipher.NewCFBEncrypter(aes_cipher, iv) | ||||
| 	stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(access_token)) | ||||
| 	return base64.StdEncoding.EncodeToString(ciphertext), nil | ||||
| } | ||||
| 
 | ||||
| func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (string, error) { | ||||
| 	encrypted_access_token, err := base64.StdEncoding.DecodeString( | ||||
| 		encoded_access_token) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("failed to decode access token") | ||||
| 	} | ||||
| 
 | ||||
| 	if len(encrypted_access_token) < aes.BlockSize { | ||||
| 		return "", fmt.Errorf("encrypted access token should be "+ | ||||
| 			"at least %d bytes, but is only %d bytes", | ||||
| 			aes.BlockSize, len(encrypted_access_token)) | ||||
| 	} | ||||
| 
 | ||||
| 	iv := encrypted_access_token[:aes.BlockSize] | ||||
| 	encrypted_access_token = encrypted_access_token[aes.BlockSize:] | ||||
| 	stream := cipher.NewCFBDecrypter(aes_cipher, iv) | ||||
| 	stream.XORKeyStream(encrypted_access_token, encrypted_access_token) | ||||
| 
 | ||||
| 	return string(encrypted_access_token), nil | ||||
| } | ||||
| 
 | ||||
| func buildCookieValue(email string, aes_cipher cipher.Block, | ||||
| 	access_token string) (string, error) { | ||||
| 	if aes_cipher == nil { | ||||
| 		return email, nil | ||||
| 	} | ||||
| 
 | ||||
| 	encoded_token, err := encodeAccessToken(aes_cipher, access_token) | ||||
| 	if err != nil { | ||||
| 		return email, fmt.Errorf( | ||||
| 			"error encoding access token for %s: %s", email, err) | ||||
| 	} | ||||
| 	return email + "|" + encoded_token, nil | ||||
| } | ||||
| 
 | ||||
| func parseCookieValue(value string, aes_cipher cipher.Block) (email, user, | ||||
| 	access_token string, err error) { | ||||
| 	components := strings.Split(value, "|") | ||||
| 	email = components[0] | ||||
| 	user = strings.Split(email, "@")[0] | ||||
| 
 | ||||
| 	if aes_cipher != nil && len(components) == 2 { | ||||
| 		access_token, err = decodeAccessToken(aes_cipher, components[1]) | ||||
| 		if err != nil { | ||||
| 			err = fmt.Errorf( | ||||
| 				"error decoding access token for %s: %s", | ||||
| 				email, err) | ||||
| 		} | ||||
| 	} | ||||
| 	return email, user, access_token, err | ||||
| } | ||||
|  | @ -1,75 +0,0 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/aes" | ||||
| 	"github.com/bmizerany/assert" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||
| 	const key = "0123456789abcdefghijklmnopqrstuv" | ||||
| 	const access_token = "my access token" | ||||
| 	c, err := aes.NewCipher([]byte(key)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	encoded_token, err := encodeAccessToken(c, access_token) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	decoded_token, err := decodeAccessToken(c, encoded_token) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	assert.NotEqual(t, access_token, encoded_token) | ||||
| 	assert.Equal(t, access_token, decoded_token) | ||||
| } | ||||
| 
 | ||||
| func TestBuildCookieValueWithoutAccessToken(t *testing.T) { | ||||
| 	value, err := buildCookieValue("michael.bland@gsa.gov", nil, "") | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", value) | ||||
| } | ||||
| 
 | ||||
| func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) { | ||||
| 	value, err := buildCookieValue("michael.bland@gsa.gov", nil, | ||||
| 		"access token") | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", value) | ||||
| } | ||||
| 
 | ||||
| func TestParseCookieValueWithoutAccessToken(t *testing.T) { | ||||
| 	email, user, access_token, err := parseCookieValue( | ||||
| 		"michael.bland@gsa.gov", nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
| 	assert.Equal(t, "michael.bland", user) | ||||
| 	assert.Equal(t, "", access_token) | ||||
| } | ||||
| 
 | ||||
| func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) { | ||||
| 	email, user, access_token, err := parseCookieValue( | ||||
| 		"michael.bland@gsa.gov|access_token", nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
| 	assert.Equal(t, "michael.bland", user) | ||||
| 	assert.Equal(t, "", access_token) | ||||
| } | ||||
| 
 | ||||
| func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) { | ||||
| 	aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef")) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	value, err := buildCookieValue("michael.bland@gsa.gov", aes_cipher, | ||||
| 		"access_token") | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	prefix := "michael.bland@gsa.gov|" | ||||
| 	if !strings.HasPrefix(value, prefix) { | ||||
| 		t.Fatal("cookie value does not start with \"%s\": %s", | ||||
| 			prefix, value) | ||||
| 	} | ||||
| 
 | ||||
| 	email, user, access_token, err := parseCookieValue(value, aes_cipher) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
| 	assert.Equal(t, "michael.bland", user) | ||||
| 	assert.Equal(t, "access_token", access_token) | ||||
| } | ||||
							
								
								
									
										283
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										283
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -1,8 +1,6 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | @ -16,6 +14,7 @@ import ( | |||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/bitly/oauth2_proxy/cookie" | ||||
| 	"github.com/bitly/oauth2_proxy/providers" | ||||
| ) | ||||
| 
 | ||||
|  | @ -44,7 +43,7 @@ type OauthProxy struct { | |||
| 	serveMux            http.Handler | ||||
| 	PassBasicAuth       bool | ||||
| 	PassAccessToken     bool | ||||
| 	AesCipher           cipher.Block | ||||
| 	CookieCipher        *cookie.Cipher | ||||
| 	skipAuthRegex       []string | ||||
| 	compiledRegex       []*regexp.Regexp | ||||
| 	templates           *template.Template | ||||
|  | @ -116,10 +115,10 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | |||
| 
 | ||||
| 	log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh) | ||||
| 
 | ||||
| 	var aes_cipher cipher.Block | ||||
| 	var cipher *cookie.Cipher | ||||
| 	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { | ||||
| 		var err error | ||||
| 		aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) | ||||
| 		cipher, err = cookie.NewCipher(opts.CookieSecret) | ||||
| 		if err != nil { | ||||
| 			log.Fatal("error creating AES cipher with "+ | ||||
| 				"cookie-secret ", opts.CookieSecret, ": ", err) | ||||
|  | @ -150,7 +149,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | |||
| 		compiledRegex:   opts.CompiledRegex, | ||||
| 		PassBasicAuth:   opts.PassBasicAuth, | ||||
| 		PassAccessToken: opts.PassAccessToken, | ||||
| 		AesCipher:       aes_cipher, | ||||
| 		CookieCipher:    cipher, | ||||
| 		templates:       loadTemplates(opts.CustomTemplatesDir), | ||||
| 	} | ||||
| } | ||||
|  | @ -177,22 +176,20 @@ func (p *OauthProxy) displayCustomLoginForm() bool { | |||
| 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) redeemCode(host, code string) (string, string, error) { | ||||
| func (p *OauthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		return "", "", errors.New("missing code") | ||||
| 		return nil, errors.New("missing code") | ||||
| 	} | ||||
| 	redirectUri := p.GetRedirectURI(host) | ||||
| 	body, access_token, err := p.provider.Redeem(redirectUri, code) | ||||
| 	s, err = p.provider.Redeem(redirectUri, code) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	email, err := p.provider.GetEmailAddress(body, access_token) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	if s.Email == "" { | ||||
| 		s.Email, err = p.provider.GetEmailAddress(s) | ||||
| 	} | ||||
| 
 | ||||
| 	return access_token, email, nil | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||
|  | @ -208,9 +205,8 @@ func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time | |||
| 	} | ||||
| 
 | ||||
| 	if value != "" { | ||||
| 		value = signedCookieValue(p.CookieSeed, p.CookieName, value, now) | ||||
| 		value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) | ||||
| 	} | ||||
| 
 | ||||
| 	return &http.Cookie{ | ||||
| 		Name:     p.CookieName, | ||||
| 		Value:    value, | ||||
|  | @ -230,35 +226,34 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st | |||
| 	http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now())) | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) { | ||||
| 	var value string | ||||
| 	var timestamp time.Time | ||||
| 	cookie, err := req.Cookie(p.CookieName) | ||||
| 	if err == nil { | ||||
| 		value, timestamp, ok = validateCookie(cookie, p.CookieSeed, p.CookieExpire) | ||||
| 		if ok { | ||||
| 			email, user, access_token, err = parseCookieValue(value, p.AesCipher) | ||||
| 		} | ||||
| 	} | ||||
| func (p *OauthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { | ||||
| 	var age time.Duration | ||||
| 	c, err := req.Cookie(p.CookieName) | ||||
| 	if err != nil { | ||||
| 		log.Printf(err.Error()) | ||||
| 		ok = false | ||||
| 	} else if ok && p.CookieRefresh != time.Duration(0) { | ||||
| 		refresh := timestamp.Add(p.CookieRefresh) | ||||
| 		if refresh.Before(time.Now()) { | ||||
| 			log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh) | ||||
| 			ok = p.Validator(email) | ||||
| 			log.Printf("re-validating %s valid:%v", email, ok) | ||||
| 			if ok { | ||||
| 				ok = p.provider.ValidateToken(access_token) | ||||
| 				log.Printf("re-validating access token. valid:%v", ok) | ||||
| 		// always http.ErrNoCookie
 | ||||
| 		return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) | ||||
| 	} | ||||
| 	val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) | ||||
| 	if !ok { | ||||
| 		return nil, age, errors.New("Cookie Signature not valid") | ||||
| 	} | ||||
| 
 | ||||
| 	session, err := p.provider.SessionFromCookie(val, p.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return nil, age, err | ||||
| 	} | ||||
| 
 | ||||
| 	age = time.Now().Truncate(time.Second).Sub(timestamp) | ||||
| 	return session, age, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { | ||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 			if ok { | ||||
| 	p.SetCookie(rw, req, value) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) { | ||||
|  | @ -344,54 +339,61 @@ func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) { | |||
| 	return redirect, err | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	// check if this is a redirect back at the end of oauth
 | ||||
| 	remoteAddr := req.RemoteAddr | ||||
| 	if req.Header.Get("X-Real-IP") != "" { | ||||
| 		remoteAddr += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) | ||||
| 	} | ||||
| 
 | ||||
| 	var ok bool | ||||
| 	var user string | ||||
| 	var email string | ||||
| 	var access_token string | ||||
| 
 | ||||
| 	if req.URL.Path == p.RobotsPath { | ||||
| 		p.RobotsTxt(rw) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if req.URL.Path == p.PingPath { | ||||
| 		p.PingPage(rw) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| func (p *OauthProxy) IsWhitelistedPath(path string) (ok bool) { | ||||
| 	for _, u := range p.compiledRegex { | ||||
| 		match := u.MatchString(req.URL.Path) | ||||
| 		if match { | ||||
| 			p.serveMux.ServeHTTP(rw, req) | ||||
| 		ok = u.MatchString(path) | ||||
| 		if ok { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| 	if req.URL.Path == p.SignInPath { | ||||
| func getRemoteAddr(req *http.Request) (s string) { | ||||
| 	s = req.RemoteAddr | ||||
| 	if req.Header.Get("X-Real-IP") != "" { | ||||
| 		s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	switch path := req.URL.Path; { | ||||
| 	case path == p.RobotsPath: | ||||
| 		p.RobotsTxt(rw) | ||||
| 	case path == p.PingPath: | ||||
| 		p.PingPage(rw) | ||||
| 	case p.IsWhitelistedPath(path): | ||||
| 		p.serveMux.ServeHTTP(rw, req) | ||||
| 	case path == p.SignInPath: | ||||
| 		p.SignIn(rw, req) | ||||
| 	case path == p.OauthStartPath: | ||||
| 		p.OauthStart(rw, req) | ||||
| 	case path == p.OauthCallbackPath: | ||||
| 		p.OauthCallback(rw, req) | ||||
| 	default: | ||||
| 		p.Proxy(rw, req) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||
| 	redirect, err := p.GetRedirect(req) | ||||
| 	if err != nil { | ||||
| 		p.ErrorPage(rw, 500, "Internal Error", err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 		user, ok = p.ManualSignIn(rw, req) | ||||
| 	user, ok := p.ManualSignIn(rw, req) | ||||
| 	if ok { | ||||
| 			p.SetCookie(rw, req, user) | ||||
| 		session := &providers.SessionState{User: user} | ||||
| 		p.SaveSession(rw, req, session) | ||||
| 		http.Redirect(rw, req, redirect, 302) | ||||
| 	} else { | ||||
| 		p.SignInPage(rw, req, 200) | ||||
| 	} | ||||
| 		return | ||||
| 	} | ||||
| 	if req.URL.Path == p.OauthStartPath { | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) OauthStart(rw http.ResponseWriter, req *http.Request) { | ||||
| 	redirect, err := p.GetRedirect(req) | ||||
| 	if err != nil { | ||||
| 		p.ErrorPage(rw, 500, "Internal Error", err.Error()) | ||||
|  | @ -399,9 +401,11 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | |||
| 	} | ||||
| 	redirectURI := p.GetRedirectURI(req.Host) | ||||
| 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302) | ||||
| 		return | ||||
| 	} | ||||
| 	if req.URL.Path == p.OauthCallbackPath { | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) { | ||||
| 	remoteAddr := getRemoteAddr(req) | ||||
| 
 | ||||
| 	// finish the oauth cycle
 | ||||
| 	err := req.ParseForm() | ||||
| 	if err != nil { | ||||
|  | @ -414,10 +418,10 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 		access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code")) | ||||
| 	session, err := p.redeemCode(req.Host, req.Form.Get("code")) | ||||
| 	if err != nil { | ||||
| 		log.Printf("%s error redeeming code %s", remoteAddr, err) | ||||
| 			p.ErrorPage(rw, 500, "Internal Error", err.Error()) | ||||
| 		p.ErrorPage(rw, 500, "Internal Error", "Internal Error") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -427,73 +431,134 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | |||
| 	} | ||||
| 
 | ||||
| 	// set cookie, or deny
 | ||||
| 		if p.Validator(email) { | ||||
| 			log.Printf("%s authenticating %s completed", remoteAddr, email) | ||||
| 			value, err := buildCookieValue( | ||||
| 				email, p.AesCipher, access_token) | ||||
| 	if p.Validator(session.Email) { | ||||
| 		log.Printf("%s authentication complete %s", remoteAddr, session) | ||||
| 		err := p.SaveSession(rw, req, session) | ||||
| 		if err != nil { | ||||
| 				log.Printf("%s", err) | ||||
| 			log.Printf("%s %s", remoteAddr, err) | ||||
| 			p.ErrorPage(rw, 500, "Internal Error", "Internal Error") | ||||
| 			return | ||||
| 		} | ||||
| 			p.SetCookie(rw, req, value) | ||||
| 		http.Redirect(rw, req, redirect, 302) | ||||
| 			return | ||||
| 	} else { | ||||
| 			log.Printf("validating: %s is unauthorized") | ||||
| 		log.Printf("%s Permission Denied: %q is unauthorized", remoteAddr, session.Email) | ||||
| 		p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||
| 	var saveSession, clearSession, revalidated bool | ||||
| 	remoteAddr := getRemoteAddr(req) | ||||
| 
 | ||||
| 	session, sessionAge, err := p.LoadCookiedSession(req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("%s %s", remoteAddr, err) | ||||
| 	} | ||||
| 	if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | ||||
| 		log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh) | ||||
| 		saveSession = true | ||||
| 	} | ||||
| 
 | ||||
| 	if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { | ||||
| 		log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) | ||||
| 		clearSession = true | ||||
| 		session = nil | ||||
| 	} else if ok { | ||||
| 		saveSession = true | ||||
| 		revalidated = true | ||||
| 	} | ||||
| 
 | ||||
| 	if session != nil && session.IsExpired() { | ||||
| 		log.Printf("%s removing session. token expired %s", remoteAddr, session) | ||||
| 		session = nil | ||||
| 		saveSession = false | ||||
| 		clearSession = true | ||||
| 	} | ||||
| 
 | ||||
| 	if saveSession && !revalidated && session.AccessToken != "" { | ||||
| 		if !p.provider.ValidateSessionState(session) { | ||||
| 			log.Printf("%s removing session. error validating %s", remoteAddr, session) | ||||
| 			saveSession = false | ||||
| 			session = nil | ||||
| 			clearSession = true | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if saveSession && session.Email != "" && !p.Validator(session.Email) { | ||||
| 		log.Printf("%s Permission Denied: removing session %s", remoteAddr, session) | ||||
| 		session = nil | ||||
| 		saveSession = false | ||||
| 		clearSession = true | ||||
| 	} | ||||
| 
 | ||||
| 	if saveSession { | ||||
| 		err := p.SaveSession(rw, req, session) | ||||
| 		if err != nil { | ||||
| 			log.Printf("%s %s", remoteAddr, err) | ||||
| 			p.ErrorPage(rw, 500, "Internal Error", "Internal Error") | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !ok { | ||||
| 		email, user, access_token, ok = p.ProcessCookie(rw, req) | ||||
| 	if clearSession { | ||||
| 		p.ClearCookie(rw, req) | ||||
| 	} | ||||
| 
 | ||||
| 	if !ok { | ||||
| 		user, ok = p.CheckBasicAuth(req) | ||||
| 	if session == nil { | ||||
| 		session, err = p.CheckBasicAuth(req) | ||||
| 		if err != nil { | ||||
| 			log.Printf("%s %s", remoteAddr, err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !ok { | ||||
| 	if session == nil { | ||||
| 		p.SignInPage(rw, req, 403) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// At this point, the user is authenticated. proxy normally
 | ||||
| 	if p.PassBasicAuth { | ||||
| 		req.SetBasicAuth(user, "") | ||||
| 		req.Header["X-Forwarded-User"] = []string{user} | ||||
| 		req.Header["X-Forwarded-Email"] = []string{email} | ||||
| 		req.SetBasicAuth(session.User, "") | ||||
| 		req.Header["X-Forwarded-User"] = []string{session.User} | ||||
| 		if session.Email != "" { | ||||
| 			req.Header["X-Forwarded-Email"] = []string{session.Email} | ||||
| 		} | ||||
| 	if p.PassAccessToken { | ||||
| 		req.Header["X-Forwarded-Access-Token"] = []string{access_token} | ||||
| 	} | ||||
| 	if email == "" { | ||||
| 		rw.Header().Set("GAP-Auth", user) | ||||
| 	if p.PassAccessToken && session.AccessToken != "" { | ||||
| 		req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} | ||||
| 	} | ||||
| 	if session.Email == "" { | ||||
| 		rw.Header().Set("GAP-Auth", session.User) | ||||
| 	} else { | ||||
| 		rw.Header().Set("GAP-Auth", email) | ||||
| 		rw.Header().Set("GAP-Auth", session.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	p.serveMux.ServeHTTP(rw, req) | ||||
| } | ||||
| 
 | ||||
| func (p *OauthProxy) CheckBasicAuth(req *http.Request) (string, bool) { | ||||
| func (p *OauthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { | ||||
| 	if p.HtpasswdFile == nil { | ||||
| 		return "", false | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	s := strings.SplitN(req.Header.Get("Authorization"), " ", 2) | ||||
| 	auth := req.Header.Get("Authorization") | ||||
| 	if auth == "" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	s := strings.SplitN(auth, " ", 2) | ||||
| 	if len(s) != 2 || s[0] != "Basic" { | ||||
| 		return "", false | ||||
| 		return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) | ||||
| 	} | ||||
| 	b, err := base64.StdEncoding.DecodeString(s[1]) | ||||
| 	if err != nil { | ||||
| 		return "", false | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	pair := strings.SplitN(string(b), ":", 2) | ||||
| 	if len(pair) != 2 { | ||||
| 		return "", false | ||||
| 		return nil, fmt.Errorf("invalid format %s", b) | ||||
| 	} | ||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | ||||
| 		log.Printf("authenticated %q via basic auth", pair[0]) | ||||
| 		return pair[0], true | ||||
| 		return &providers.SessionState{User: pair[0]}, nil | ||||
| 	} | ||||
| 	return "", false | ||||
| 	return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) | ||||
| } | ||||
|  |  | |||
|  | @ -94,11 +94,11 @@ type TestProvider struct { | |||
| 	ValidToken   bool | ||||
| } | ||||
| 
 | ||||
| func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||
| func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { | ||||
| 	return tp.EmailAddress, nil | ||||
| } | ||||
| 
 | ||||
| func (tp *TestProvider) ValidateToken(access_token string) bool { | ||||
| func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { | ||||
| 	return tp.ValidToken | ||||
| } | ||||
| 
 | ||||
|  | @ -378,97 +378,73 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { | |||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) MakeCookie(value, access_token string, ref time.Time) *http.Cookie { | ||||
| 	cookie_value, _ := buildCookieValue(value, p.proxy.AesCipher, access_token) | ||||
| 	return p.proxy.MakeCookie(p.req, cookie_value, p.opts.CookieExpire, ref) | ||||
| func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie { | ||||
| 	return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref) | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) AddCookie(value, access_token string) { | ||||
| 	p.req.AddCookie(p.MakeCookie(value, access_token, time.Now())) | ||||
| func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { | ||||
| 	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref)) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, ok bool) { | ||||
| 	return p.proxy.ProcessCookie(p.rw, p.req) | ||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { | ||||
| 	return p.proxy.LoadCookiedSession(p.req) | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookie(t *testing.T) { | ||||
| func TestLoadCookiedSession(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 
 | ||||
| 	pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token") | ||||
| 	email, user, access_token, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, true, ok) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
| 	assert.Equal(t, "michael.bland", user) | ||||
| 	assert.Equal(t, "my_access_token", access_token) | ||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pc_test.SaveSession(startSession, time.Now()) | ||||
| 
 | ||||
| 	session, _, err := pc_test.LoadCookiedSession() | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, startSession.Email, session.Email) | ||||
| 	assert.Equal(t, "michael.bland", session.User) | ||||
| 	assert.Equal(t, startSession.AccessToken, session.AccessToken) | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieNoCookieError(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	_, _, _, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, false, ok) | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	value, _ := buildCookieValue("michael.bland@gsa.gov", | ||||
| 		pc_test.proxy.AesCipher, "my_access_token") | ||||
| 	pc_test.req.AddCookie(pc_test.proxy.MakeCookie( | ||||
| 		pc_test.req, value+"some bogus bytes", | ||||
| 		pc_test.opts.CookieExpire, time.Now())) | ||||
| 	_, _, _, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, false, ok) | ||||
| 	session, _, err := pc_test.LoadCookiedSession() | ||||
| 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected nil session. got %#v", session) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieRefreshNotSet(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||
| 	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "", reference) | ||||
| 	pc_test.req.AddCookie(cookie) | ||||
| 
 | ||||
| 	_, _, _, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, true, ok) | ||||
| 	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) | ||||
| } | ||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pc_test.SaveSession(startSession, reference) | ||||
| 
 | ||||
| func TestProcessCookieRefresh(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||
| 	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) | ||||
| 	pc_test.req.AddCookie(cookie) | ||||
| 
 | ||||
| 	pc_test.proxy.CookieRefresh = time.Hour | ||||
| 	_, _, _, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, true, ok) | ||||
| 	assert.NotEqual(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(-30) * time.Minute) | ||||
| 	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) | ||||
| 	pc_test.req.AddCookie(cookie) | ||||
| 
 | ||||
| 	pc_test.proxy.CookieRefresh = time.Hour | ||||
| 	_, _, _, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, true, ok) | ||||
| 	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) | ||||
| 	session, age, err := pc_test.LoadCookiedSession() | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	if age < time.Duration(-2)*time.Hour { | ||||
| 		t.Errorf("cookie too young %v", age) | ||||
| 	} | ||||
| 	assert.Equal(t, startSession.Email, session.Email) | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) | ||||
| 	pc_test.req.AddCookie(cookie) | ||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pc_test.SaveSession(startSession, reference) | ||||
| 
 | ||||
| 	if _, _, _, ok := pc_test.ProcessCookie(); ok { | ||||
| 		t.Error("ProcessCookie() should have failed") | ||||
| 	} | ||||
| 	if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil { | ||||
| 		t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie) | ||||
| 	session, _, err := pc_test.LoadCookiedSession() | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected nil session %#v", session) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -476,44 +452,13 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | |||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) | ||||
| 	pc_test.req.AddCookie(cookie) | ||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pc_test.SaveSession(startSession, reference) | ||||
| 
 | ||||
| 	pc_test.proxy.CookieRefresh = time.Hour | ||||
| 	if _, _, _, ok := pc_test.ProcessCookie(); ok { | ||||
| 		t.Error("ProcessCookie() should have failed") | ||||
| 	} | ||||
| 	if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil { | ||||
| 		t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie) | ||||
| 	session, _, err := pc_test.LoadCookiedSession() | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected nil session %#v", session) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTest(ProcessCookieTestOpts{ | ||||
| 		provider_validate_cookie_response: false, | ||||
| 	}) | ||||
| 	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(-24) * time.Hour) | ||||
| 	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) | ||||
| 	pc_test.req.AddCookie(cookie) | ||||
| 
 | ||||
| 	pc_test.proxy.CookieRefresh = time.Hour | ||||
| 	_, _, _, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, false, ok) | ||||
| 	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) { | ||||
| 	pc_test := NewProcessCookieTestWithDefaults() | ||||
| 	pc_test.validate_user = false | ||||
| 
 | ||||
| 	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||
| 	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) | ||||
| 	pc_test.req.AddCookie(cookie) | ||||
| 
 | ||||
| 	pc_test.proxy.CookieRefresh = time.Hour | ||||
| 	_, _, _, ok := pc_test.ProcessCookie() | ||||
| 	assert.Equal(t, false, ok) | ||||
| 	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) | ||||
| } | ||||
|  |  | |||
|  | @ -2,8 +2,10 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| ) | ||||
|  | @ -138,7 +140,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { | |||
| 	return false, nil | ||||
| } | ||||
| 
 | ||||
| func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||
| func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| 
 | ||||
| 	var emails []struct { | ||||
| 		Email   string `json:"email"` | ||||
|  | @ -148,31 +150,34 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri | |||
| 	// if we require an Org or Team, check that first
 | ||||
| 	if p.Org != "" { | ||||
| 		if p.Team != "" { | ||||
| 			if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok { | ||||
| 			if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { | ||||
| 				return "", err | ||||
| 			} | ||||
| 		} else { | ||||
| 			if ok, err := p.hasOrg(access_token); err != nil || !ok { | ||||
| 			if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { | ||||
| 				return "", err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	params := url.Values{ | ||||
| 		"access_token": {access_token}, | ||||
| 		"access_token": {s.AccessToken}, | ||||
| 	} | ||||
| 	endpoint := "https://api.github.com/user/emails?" + params.Encode() | ||||
| 	resp, err := http.DefaultClient.Get(endpoint) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body) | ||||
| 	} else { | ||||
| 		log.Printf("got %d from %q %s", resp.StatusCode, endpoint, body) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := json.Unmarshal(body, &emails); err != nil { | ||||
|  | @ -185,9 +190,5 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return "", nil | ||||
| } | ||||
| 
 | ||||
| func (p *GitHubProvider) ValidateToken(access_token string) bool { | ||||
| 	return validateToken(p, access_token, nil) | ||||
| 	return "", errors.New("no email address found") | ||||
| } | ||||
|  |  | |||
|  | @ -7,9 +7,11 @@ import ( | |||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type GoogleProvider struct { | ||||
|  | @ -43,18 +45,11 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { | |||
| 	return &GoogleProvider{ProviderData: p} | ||||
| } | ||||
| 
 | ||||
| func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||
| 	var response struct { | ||||
| 		IdToken string `json:"id_token"` | ||||
| 	} | ||||
| 
 | ||||
| 	if err := json.Unmarshal(body, &response); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| func emailFromIdToken(idToken string) (string, error) { | ||||
| 
 | ||||
| 	// id_token is a base64 encode ID token payload
 | ||||
| 	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | ||||
| 	jwt := strings.Split(response.IdToken, ".") | ||||
| 	jwt := strings.Split(idToken, ".") | ||||
| 	b, err := jwtDecodeSegment(jwt[1]) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
|  | @ -62,6 +57,7 @@ func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (stri | |||
| 
 | ||||
| 	var email struct { | ||||
| 		Email         string `json:"email"` | ||||
| 		EmailVerified bool   `json:"email_verified"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(b, &email) | ||||
| 	if err != nil { | ||||
|  | @ -70,6 +66,9 @@ func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (stri | |||
| 	if email.Email == "" { | ||||
| 		return "", errors.New("missing email") | ||||
| 	} | ||||
| 	if !email.EmailVerified { | ||||
| 		return "", fmt.Errorf("email %s not listed as verified", email.Email) | ||||
| 	} | ||||
| 	return email.Email, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -81,11 +80,7 @@ func jwtDecodeSegment(seg string) ([]byte, error) { | |||
| 	return base64.URLEncoding.DecodeString(seg) | ||||
| } | ||||
| 
 | ||||
| func (p *GoogleProvider) ValidateToken(access_token string) bool { | ||||
| 	return validateToken(p, access_token, nil) | ||||
| } | ||||
| 
 | ||||
| func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token string, err error) { | ||||
| func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
|  | @ -108,6 +103,7 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st | |||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
|  | @ -122,17 +118,44 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st | |||
| 	var jsonResponse struct { | ||||
| 		AccessToken  string `json:"access_token"` | ||||
| 		RefreshToken string `json:"refresh_token"` | ||||
| 		ExpiresIn    int64  `json:"expires_in"` | ||||
| 		IdToken      string `json:"id_token"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	token, err = p.redeemRefreshToken(jsonResponse.RefreshToken) | ||||
| 	var email string | ||||
| 	email, err = emailFromIdToken(jsonResponse.IdToken) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	s = &SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||
| 		RefreshToken: jsonResponse.RefreshToken, | ||||
| 		Email:        email, | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, err error) { | ||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	newToken, duration, err := p.redeemRefreshToken(s.RefreshToken) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	origExpiration := s.ExpiresOn | ||||
| 	s.AccessToken = newToken | ||||
| 	s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) | ||||
| 	log.Printf("refreshed access token %s (expired on %s)", s, origExpiration) | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) { | ||||
| 	// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
 | ||||
| 	params := url.Values{} | ||||
| 	params.Add("client_id", p.ClientID) | ||||
|  | @ -162,12 +185,15 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	var jsonResponse struct { | ||||
| 	var data struct { | ||||
| 		AccessToken string `json:"access_token"` | ||||
| 		ExpiresIn   int64  `json:"expires_in"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 	err = json.Unmarshal(body, &data) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	return jsonResponse.AccessToken, nil | ||||
| 	token = data.AccessToken | ||||
| 	expires = time.Duration(data.ExpiresIn) * time.Second | ||||
| 	return | ||||
| } | ||||
|  |  | |||
|  | @ -3,11 +3,22 @@ package providers | |||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"github.com/bmizerany/assert" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/bmizerany/assert" | ||||
| ) | ||||
| 
 | ||||
| func newRedeemServer(body []byte) (*url.URL, *httptest.Server) { | ||||
| 	s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | ||||
| 		rw.Write(body) | ||||
| 	})) | ||||
| 	u, _ := url.Parse(s.URL) | ||||
| 	return u, s | ||||
| } | ||||
| 
 | ||||
| func newGoogleProvider() *GoogleProvider { | ||||
| 	return NewGoogleProvider( | ||||
| 		&ProviderData{ | ||||
|  | @ -66,63 +77,88 @@ func TestGoogleProviderOverrides(t *testing.T) { | |||
| 	assert.Equal(t, "profile", p.Data().Scope) | ||||
| } | ||||
| 
 | ||||
| func TestGoogleProviderGetEmailAddress(t *testing.T) { | ||||
| 	p := newGoogleProvider() | ||||
| 	body, err := json.Marshal( | ||||
| 		struct { | ||||
| type redeemResponse struct { | ||||
| 	AccessToken  string `json:"access_token"` | ||||
| 	RefreshToken string `json:"refresh_token"` | ||||
| 	ExpiresIn    int64  `json:"expires_in"` | ||||
| 	IdToken      string `json:"id_token"` | ||||
| 		}{ | ||||
| 			IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)), | ||||
| 		}, | ||||
| 	) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
| 	assert.Equal(t, nil, err) | ||||
| } | ||||
| 
 | ||||
| func TestGoogleProviderGetEmailAddress(t *testing.T) { | ||||
| 	p := newGoogleProvider() | ||||
| 	body, err := json.Marshal(redeemResponse{ | ||||
| 		AccessToken:  "a1234", | ||||
| 		ExpiresIn:    10, | ||||
| 		RefreshToken: "refresh12345", | ||||
| 		IdToken:      "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), | ||||
| 	}) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	var server *httptest.Server | ||||
| 	p.RedeemUrl, server = newRedeemServer(body) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	session, err := p.Redeem("http://redirect/", "code1234") | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.NotEqual(t, session, nil) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", session.Email) | ||||
| 	assert.Equal(t, "a1234", session.AccessToken) | ||||
| 	assert.Equal(t, "refresh12345", session.RefreshToken) | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { | ||||
| 	p := newGoogleProvider() | ||||
| 	body, err := json.Marshal( | ||||
| 		struct { | ||||
| 			IdToken string `json:"id_token"` | ||||
| 		}{ | ||||
| 	body, err := json.Marshal(redeemResponse{ | ||||
| 		AccessToken: "a1234", | ||||
| 		IdToken:     "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, | ||||
| 		}, | ||||
| 	) | ||||
| 	}) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||
| 	assert.Equal(t, "", email) | ||||
| 	var server *httptest.Server | ||||
| 	p.RedeemUrl, server = newRedeemServer(body) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	session, err := p.Redeem("http://redirect/", "code1234") | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expect nill session %#v", session) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { | ||||
| 	p := newGoogleProvider() | ||||
| 
 | ||||
| 	body, err := json.Marshal( | ||||
| 		struct { | ||||
| 			IdToken string `json:"id_token"` | ||||
| 		}{ | ||||
| 	body, err := json.Marshal(redeemResponse{ | ||||
| 		AccessToken: "a1234", | ||||
| 		IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), | ||||
| 		}, | ||||
| 	) | ||||
| 	}) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||
| 	assert.Equal(t, "", email) | ||||
| 	var server *httptest.Server | ||||
| 	p.RedeemUrl, server = newRedeemServer(body) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	session, err := p.Redeem("http://redirect/", "code1234") | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expect nill session %#v", session) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { | ||||
| 	p := newGoogleProvider() | ||||
| 	body, err := json.Marshal( | ||||
| 		struct { | ||||
| 			IdToken string `json:"id_token"` | ||||
| 		}{ | ||||
| 	body, err := json.Marshal(redeemResponse{ | ||||
| 		AccessToken: "a1234", | ||||
| 		IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), | ||||
| 		}, | ||||
| 	) | ||||
| 	}) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	email, err := p.GetEmailAddress(body, "ignored access_token") | ||||
| 	assert.Equal(t, "", email) | ||||
| 	var server *httptest.Server | ||||
| 	p.RedeemUrl, server = newRedeemServer(body) | ||||
| 	defer server.Close() | ||||
| 
 | ||||
| 	session, err := p.Redeem("http://redirect/", "code1234") | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expect nill session %#v", session) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ import ( | |||
| 	"github.com/bitly/oauth2_proxy/api" | ||||
| ) | ||||
| 
 | ||||
| // validateToken returns true if token is valid
 | ||||
| func validateToken(p Provider, access_token string, header http.Header) bool { | ||||
| 	if access_token == "" || p.Data().ValidateUrl == nil { | ||||
| 		return false | ||||
|  | @ -20,12 +21,15 @@ func validateToken(p Provider, access_token string, header http.Header) bool { | |||
| 	} | ||||
| 	resp, err := api.RequestUnparsedResponse(endpoint, header) | ||||
| 	if err != nil { | ||||
| 		log.Printf("GET %s", endpoint) | ||||
| 		log.Printf("token validation request failed: %s", err) | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	body, _ := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	log.Printf("%d GET %s %s", resp.StatusCode, endpoint, body) | ||||
| 
 | ||||
| 	if resp.StatusCode == 200 { | ||||
| 		return true | ||||
| 	} | ||||
|  |  | |||
|  | @ -1,36 +1,38 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/bmizerany/assert" | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/bmizerany/assert" | ||||
| ) | ||||
| 
 | ||||
| type ValidateTokenTestProvider struct { | ||||
| type ValidateSessionStateTestProvider struct { | ||||
| 	*ProviderData | ||||
| } | ||||
| 
 | ||||
| func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||
| 	return "", nil | ||||
| func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| 	return "", errors.New("not implemented") | ||||
| } | ||||
| 
 | ||||
| // Note that we're testing the internal validateToken() used to implement
 | ||||
| // several Provider's ValidateToken() implementations
 | ||||
| func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool { | ||||
| // several Provider's ValidateSessionState() implementations
 | ||||
| func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| type ValidateTokenTest struct { | ||||
| type ValidateSessionStateTest struct { | ||||
| 	backend       *httptest.Server | ||||
| 	response_code int | ||||
| 	provider      *ValidateTokenTestProvider | ||||
| 	provider      *ValidateSessionStateTestProvider | ||||
| 	header        http.Header | ||||
| } | ||||
| 
 | ||||
| func NewValidateTokenTest() *ValidateTokenTest { | ||||
| 	var vt_test ValidateTokenTest | ||||
| func NewValidateSessionStateTest() *ValidateSessionStateTest { | ||||
| 	var vt_test ValidateSessionStateTest | ||||
| 
 | ||||
| 	vt_test.backend = httptest.NewServer( | ||||
| 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
|  | @ -59,7 +61,7 @@ func NewValidateTokenTest() *ValidateTokenTest { | |||
| 
 | ||||
| 		})) | ||||
| 	backend_url, _ := url.Parse(vt_test.backend.URL) | ||||
| 	vt_test.provider = &ValidateTokenTestProvider{ | ||||
| 	vt_test.provider = &ValidateSessionStateTestProvider{ | ||||
| 		ProviderData: &ProviderData{ | ||||
| 			ValidateUrl: &url.URL{ | ||||
| 				Scheme: "http", | ||||
|  | @ -72,18 +74,18 @@ func NewValidateTokenTest() *ValidateTokenTest { | |||
| 	return &vt_test | ||||
| } | ||||
| 
 | ||||
| func (vt_test *ValidateTokenTest) Close() { | ||||
| func (vt_test *ValidateSessionStateTest) Close() { | ||||
| 	vt_test.backend.Close() | ||||
| } | ||||
| 
 | ||||
| func TestValidateTokenValidToken(t *testing.T) { | ||||
| 	vt_test := NewValidateTokenTest() | ||||
| func TestValidateSessionStateValidToken(t *testing.T) { | ||||
| 	vt_test := NewValidateSessionStateTest() | ||||
| 	defer vt_test.Close() | ||||
| 	assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) | ||||
| } | ||||
| 
 | ||||
| func TestValidateTokenValidTokenWithHeaders(t *testing.T) { | ||||
| 	vt_test := NewValidateTokenTest() | ||||
| func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { | ||||
| 	vt_test := NewValidateSessionStateTest() | ||||
| 	defer vt_test.Close() | ||||
| 	vt_test.header = make(http.Header) | ||||
| 	vt_test.header.Set("Authorization", "Bearer foobar") | ||||
|  | @ -91,28 +93,28 @@ func TestValidateTokenValidTokenWithHeaders(t *testing.T) { | |||
| 		validateToken(vt_test.provider, "foobar", vt_test.header)) | ||||
| } | ||||
| 
 | ||||
| func TestValidateTokenEmptyToken(t *testing.T) { | ||||
| 	vt_test := NewValidateTokenTest() | ||||
| func TestValidateSessionStateEmptyToken(t *testing.T) { | ||||
| 	vt_test := NewValidateSessionStateTest() | ||||
| 	defer vt_test.Close() | ||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) | ||||
| } | ||||
| 
 | ||||
| func TestValidateTokenEmptyValidateUrl(t *testing.T) { | ||||
| 	vt_test := NewValidateTokenTest() | ||||
| func TestValidateSessionStateEmptyValidateUrl(t *testing.T) { | ||||
| 	vt_test := NewValidateSessionStateTest() | ||||
| 	defer vt_test.Close() | ||||
| 	vt_test.provider.Data().ValidateUrl = nil | ||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) | ||||
| } | ||||
| 
 | ||||
| func TestValidateTokenRequestNetworkFailure(t *testing.T) { | ||||
| 	vt_test := NewValidateTokenTest() | ||||
| func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { | ||||
| 	vt_test := NewValidateSessionStateTest() | ||||
| 	// Close immediately to simulate a network failure
 | ||||
| 	vt_test.Close() | ||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) | ||||
| } | ||||
| 
 | ||||
| func TestValidateTokenExpiredToken(t *testing.T) { | ||||
| 	vt_test := NewValidateTokenTest() | ||||
| func TestValidateSessionStateExpiredToken(t *testing.T) { | ||||
| 	vt_test := NewValidateSessionStateTest() | ||||
| 	defer vt_test.Close() | ||||
| 	vt_test.response_code = 401 | ||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) | ||||
|  |  | |||
|  | @ -1,7 +1,6 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
|  | @ -49,16 +48,15 @@ func getLinkedInHeader(access_token string) http.Header { | |||
| 	return header | ||||
| } | ||||
| 
 | ||||
| func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||
| 	if access_token == "" { | ||||
| func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| 	if s.AccessToken == "" { | ||||
| 		return "", errors.New("missing access token") | ||||
| 	} | ||||
| 	params := url.Values{} | ||||
| 	req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", bytes.NewBufferString(params.Encode())) | ||||
| 	req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", nil) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header = getLinkedInHeader(access_token) | ||||
| 	req.Header = getLinkedInHeader(s.AccessToken) | ||||
| 
 | ||||
| 	json, err := api.Request(req) | ||||
| 	if err != nil { | ||||
|  | @ -74,6 +72,6 @@ func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (st | |||
| 	return email, nil | ||||
| } | ||||
| 
 | ||||
| func (p *LinkedInProvider) ValidateToken(access_token string) bool { | ||||
| 	return validateToken(p, access_token, getLinkedInHeader(access_token)) | ||||
| func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { | ||||
| 	return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | ||||
| } | ||||
|  |  | |||
|  | @ -97,8 +97,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { | |||
| 	b_url, _ := url.Parse(b.URL) | ||||
| 	p := testLinkedInProvider(b_url.Host) | ||||
| 
 | ||||
| 	email, err := p.GetEmailAddress([]byte{}, | ||||
| 		"imaginary_access_token") | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "user@linkedin.com", email) | ||||
| } | ||||
|  | @ -113,7 +113,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { | |||
| 	// We'll trigger a request failure by using an unexpected access
 | ||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||
| 	// JSON to fail.
 | ||||
| 	email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") | ||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
| } | ||||
|  | @ -125,7 +126,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | |||
| 	b_url, _ := url.Parse(b.URL) | ||||
| 	p := testLinkedInProvider(b_url.Host) | ||||
| 
 | ||||
| 	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
| } | ||||
|  |  | |||
|  | @ -42,9 +42,9 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider { | |||
| 	return &MyUsaProvider{ProviderData: p} | ||||
| } | ||||
| 
 | ||||
| func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) { | ||||
| func (p *MyUsaProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| 	req, err := http.NewRequest("GET", | ||||
| 		p.ProfileUrl.String()+"?access_token="+access_token, nil) | ||||
| 		p.ProfileUrl.String()+"?access_token="+s.AccessToken, nil) | ||||
| 	if err != nil { | ||||
| 		log.Printf("failed building request %s", err) | ||||
| 		return "", err | ||||
|  | @ -56,7 +56,3 @@ func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (strin | |||
| 	} | ||||
| 	return json.Get("email").String() | ||||
| } | ||||
| 
 | ||||
| func (p *MyUsaProvider) ValidateToken(access_token string) bool { | ||||
| 	return validateToken(p, access_token, nil) | ||||
| } | ||||
|  |  | |||
|  | @ -1,11 +1,12 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/bmizerany/assert" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/bmizerany/assert" | ||||
| ) | ||||
| 
 | ||||
| func updateUrl(url *url.URL, hostname string) { | ||||
|  | @ -102,7 +103,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) { | |||
| 	b_url, _ := url.Parse(b.URL) | ||||
| 	p := testMyUsaProvider(b_url.Host) | ||||
| 
 | ||||
| 	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
| } | ||||
|  | @ -119,7 +121,8 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) { | |||
| 	// We'll trigger a request failure by using an unexpected access
 | ||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||
| 	// JSON to fail.
 | ||||
| 	email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") | ||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
| } | ||||
|  | @ -131,7 +134,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | |||
| 	b_url, _ := url.Parse(b.URL) | ||||
| 	p := testMyUsaProvider(b_url.Host) | ||||
| 
 | ||||
| 	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
| } | ||||
|  |  | |||
|  | @ -9,9 +9,11 @@ import ( | |||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/bitly/oauth2_proxy/cookie" | ||||
| ) | ||||
| 
 | ||||
| func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) { | ||||
| func (p *ProviderData) Redeem(redirectUrl, code string) (s *SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
|  | @ -23,24 +25,28 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri | |||
| 	params.Add("client_secret", p.ClientSecret) | ||||
| 	params.Add("code", code) | ||||
| 	params.Add("grant_type", "authorization_code") | ||||
| 	req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) | ||||
| 	var req *http.Request | ||||
| 	req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) | ||||
| 	if err != nil { | ||||
| 		return nil, "", err | ||||
| 		return | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	var resp *http.Response | ||||
| 	resp, err = http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, "", err | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, "", err | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return body, "", fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) | ||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// blindly try json and x-www-form-urlencoded
 | ||||
|  | @ -49,11 +55,23 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri | |||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 	if err == nil { | ||||
| 		return body, jsonResponse.AccessToken, nil | ||||
| 		s = &SessionState{ | ||||
| 			AccessToken: jsonResponse.AccessToken, | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	v, err := url.ParseQuery(string(body)) | ||||
| 	return body, v.Get("access_token"), err | ||||
| 	var v url.Values | ||||
| 	v, err = url.ParseQuery(string(body)) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if a := v.Get("access_token"); a != "" { | ||||
| 		s = &SessionState{AccessToken: a} | ||||
| 	} else { | ||||
| 		err = fmt.Errorf("no access token found %s", body) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // GetLoginURL with typical oauth parameters
 | ||||
|  | @ -72,3 +90,26 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { | |||
| 	a.RawQuery = params.Encode() | ||||
| 	return a.String() | ||||
| } | ||||
| 
 | ||||
| // CookieForSession serializes a session state for storage in a cookie
 | ||||
| func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { | ||||
| 	return s.EncodeSessionState(c) | ||||
| } | ||||
| 
 | ||||
| // SessionFromCookie deserializes a session from a cookie value
 | ||||
| func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { | ||||
| 	return DecodeSessionState(v, c) | ||||
| } | ||||
| 
 | ||||
| func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { | ||||
| 	return "", errors.New("not implemented") | ||||
| } | ||||
| 
 | ||||
| func (p *ProviderData) ValidateSessionState(s *SessionState) bool { | ||||
| 	return validateToken(p, s.AccessToken, nil) | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded
 | ||||
| func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||
| 	return false, nil | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,17 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/bmizerany/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestRefresh(t *testing.T) { | ||||
| 	p := &ProviderData{} | ||||
| 	refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ | ||||
| 		ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), | ||||
| 	}) | ||||
| 	assert.Equal(t, false, refreshed) | ||||
| 	assert.Equal(t, nil, err) | ||||
| } | ||||
|  | @ -1,11 +1,18 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/bitly/oauth2_proxy/cookie" | ||||
| ) | ||||
| 
 | ||||
| type Provider interface { | ||||
| 	Data() *ProviderData | ||||
| 	GetEmailAddress(body []byte, access_token string) (string, error) | ||||
| 	Redeem(string, string) ([]byte, string, error) | ||||
| 	ValidateToken(access_token string) bool | ||||
| 	GetEmailAddress(*SessionState) (string, error) | ||||
| 	Redeem(string, string) (*SessionState, error) | ||||
| 	ValidateSessionState(*SessionState) bool | ||||
| 	GetLoginURL(redirectURI, finalRedirect string) string | ||||
| 	RefreshSessionIfNeeded(*SessionState) (bool, error) | ||||
| 	SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) | ||||
| 	CookieForSession(*SessionState, *cookie.Cipher) (string, error) | ||||
| } | ||||
| 
 | ||||
| func New(provider string, p *ProviderData) Provider { | ||||
|  |  | |||
|  | @ -0,0 +1,115 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/bitly/oauth2_proxy/cookie" | ||||
| ) | ||||
| 
 | ||||
| type SessionState struct { | ||||
| 	AccessToken  string | ||||
| 	ExpiresOn    time.Time | ||||
| 	RefreshToken string | ||||
| 	Email        string | ||||
| 	User         string | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) IsExpired() bool { | ||||
| 	if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) String() string { | ||||
| 	o := fmt.Sprintf("Session{%s", s.userOrEmail()) | ||||
| 	if s.AccessToken != "" { | ||||
| 		o += " token:true" | ||||
| 	} | ||||
| 	if !s.ExpiresOn.IsZero() { | ||||
| 		o += fmt.Sprintf(" expires:%s", s.ExpiresOn) | ||||
| 	} | ||||
| 	if s.RefreshToken != "" { | ||||
| 		o += " refresh_token:true" | ||||
| 	} | ||||
| 	return o + "}" | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { | ||||
| 	if c == nil || s.AccessToken == "" { | ||||
| 		return s.userOrEmail(), nil | ||||
| 	} | ||||
| 	return s.EncryptedString(c) | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) userOrEmail() string { | ||||
| 	u := s.User | ||||
| 	if s.Email != "" { | ||||
| 		u = s.Email | ||||
| 	} | ||||
| 	return u | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { | ||||
| 	var err error | ||||
| 	if c == nil { | ||||
| 		panic("error. missing cipher") | ||||
| 	} | ||||
| 	a := s.AccessToken | ||||
| 	if a != "" { | ||||
| 		a, err = c.Encrypt(a) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	r := s.RefreshToken | ||||
| 	if r != "" { | ||||
| 		r, err = c.Encrypt(r) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil | ||||
| } | ||||
| 
 | ||||
| func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { | ||||
| 	chunks := strings.Split(v, "|") | ||||
| 	if len(chunks) == 1 { | ||||
| 		if strings.Contains(chunks[0], "@") { | ||||
| 			u := strings.Split(v, "@")[0] | ||||
| 			return &SessionState{Email: v, User: u}, nil | ||||
| 		} | ||||
| 		return &SessionState{User: v}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	if len(chunks) != 4 { | ||||
| 		err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	s = &SessionState{} | ||||
| 	if c != nil && chunks[1] != "" { | ||||
| 		s.AccessToken, err = c.Decrypt(chunks[1]) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	if c != nil && chunks[3] != "" { | ||||
| 		s.RefreshToken, err = c.Decrypt(chunks[3]) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	if u := chunks[0]; strings.Contains(u, "@") { | ||||
| 		s.Email = u | ||||
| 		s.User = strings.Split(u, "@")[0] | ||||
| 	} else { | ||||
| 		s.User = u | ||||
| 	} | ||||
| 	ts, _ := strconv.Atoi(chunks[2]) | ||||
| 	s.ExpiresOn = time.Unix(int64(ts), 0) | ||||
| 	return | ||||
| } | ||||
|  | @ -0,0 +1,88 @@ | |||
| package providers | ||||
| 
 | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/bitly/oauth2_proxy/cookie" | ||||
| 	"github.com/bmizerany/assert" | ||||
| ) | ||||
| 
 | ||||
| const secret = "0123456789abcdefghijklmnopqrstuv" | ||||
| const altSecret = "0000000000abcdefghijklmnopqrstuv" | ||||
| 
 | ||||
| func TestSessionStateSerialization(t *testing.T) { | ||||
| 	c, err := cookie.NewCipher(secret) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	c2, err := cookie.NewCipher(altSecret) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	s := &SessionState{ | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
| 		RefreshToken: "refresh4321", | ||||
| 	} | ||||
| 	encoded, err := s.EncodeSessionState(c) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, 3, strings.Count(encoded, "|")) | ||||
| 
 | ||||
| 	ss, err := DecodeSessionState(encoded, c) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
| 
 | ||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | ||||
| 	ss, err = DecodeSessionState(encoded, c2) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) | ||||
| } | ||||
| 
 | ||||
| func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||
| 
 | ||||
| 	s := &SessionState{ | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
| 		RefreshToken: "refresh4321", | ||||
| 	} | ||||
| 	encoded, err := s.EncodeSessionState(nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.Email, encoded) | ||||
| 
 | ||||
| 	// only email should have been serialized
 | ||||
| 	ss, err := DecodeSessionState(encoded, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, "", ss.AccessToken) | ||||
| 	assert.Equal(t, "", ss.RefreshToken) | ||||
| } | ||||
| 
 | ||||
| func TestSessionStateUserOrEmail(t *testing.T) { | ||||
| 
 | ||||
| 	s := &SessionState{ | ||||
| 		Email: "user@domain.com", | ||||
| 		User:  "just-user", | ||||
| 	} | ||||
| 	assert.Equal(t, "user@domain.com", s.userOrEmail()) | ||||
| 	s.Email = "" | ||||
| 	assert.Equal(t, "just-user", s.userOrEmail()) | ||||
| } | ||||
| 
 | ||||
| func TestExpired(t *testing.T) { | ||||
| 	s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} | ||||
| 	assert.Equal(t, true, s.IsExpired()) | ||||
| 
 | ||||
| 	s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} | ||||
| 	assert.Equal(t, false, s.IsExpired()) | ||||
| 
 | ||||
| 	s = &SessionState{} | ||||
| 	assert.Equal(t, false, s.IsExpired()) | ||||
| } | ||||
		Loading…
	
		Reference in New Issue