SessionState refactoring; improve token renewal and cookie refresh
* New SessionState to consolidate email, access token and refresh token * split ServeHttp into individual methods * log on session renewal * log on access token refresh * refactor cookie encription/decription and session state serialization
This commit is contained in:
		
							parent
							
								
									b9ae5dc8d7
								
							
						
					
					
						commit
						d49c3e167f
					
				| 
						 | 
					@ -3,6 +3,7 @@ package api
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/bitly/go-simplejson"
 | 
						"github.com/bitly/go-simplejson"
 | 
				
			||||||
| 
						 | 
					@ -11,10 +12,12 @@ import (
 | 
				
			||||||
func Request(req *http.Request) (*simplejson.Json, error) {
 | 
					func Request(req *http.Request) (*simplejson.Json, error) {
 | 
				
			||||||
	resp, err := http.DefaultClient.Do(req)
 | 
						resp, err := http.DefaultClient.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Printf("%s %s %s", req.Method, req.URL, err)
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	body, err := ioutil.ReadAll(resp.Body)
 | 
						body, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
	resp.Body.Close()
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,128 @@
 | 
				
			||||||
 | 
					package cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/aes"
 | 
				
			||||||
 | 
						"crypto/cipher"
 | 
				
			||||||
 | 
						"crypto/hmac"
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"crypto/sha1"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
 | 
				
			||||||
 | 
					// additionally, the 'value' is encrypted so it's opaque to the browser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Validate ensures a cookie is properly signed
 | 
				
			||||||
 | 
					func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) {
 | 
				
			||||||
 | 
						// value, timestamp, sig
 | 
				
			||||||
 | 
						parts := strings.Split(cookie.Value, "|")
 | 
				
			||||||
 | 
						if len(parts) != 3 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						sig := cookieSignature(seed, cookie.Name, parts[0], parts[1])
 | 
				
			||||||
 | 
						if checkHmac(parts[2], sig) {
 | 
				
			||||||
 | 
							ts, err := strconv.Atoi(parts[1])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// The expiration timestamp set when the cookie was created
 | 
				
			||||||
 | 
							// isn't sent back by the browser. Hence, we check whether the
 | 
				
			||||||
 | 
							// creation timestamp stored in the cookie falls within the
 | 
				
			||||||
 | 
							// window defined by (Now()-expiration, Now()].
 | 
				
			||||||
 | 
							t = time.Unix(int64(ts), 0)
 | 
				
			||||||
 | 
							if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
 | 
				
			||||||
 | 
								// it's a valid cookie. now get the contents
 | 
				
			||||||
 | 
								rawValue, err := base64.URLEncoding.DecodeString(parts[0])
 | 
				
			||||||
 | 
								if err == nil {
 | 
				
			||||||
 | 
									value = string(rawValue)
 | 
				
			||||||
 | 
									ok = true
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SignedValue returns a cookie that is signed and can later be checked with Validate
 | 
				
			||||||
 | 
					func SignedValue(seed string, key string, value string, now time.Time) string {
 | 
				
			||||||
 | 
						encodedValue := base64.URLEncoding.EncodeToString([]byte(value))
 | 
				
			||||||
 | 
						timeStr := fmt.Sprintf("%d", now.Unix())
 | 
				
			||||||
 | 
						sig := cookieSignature(seed, key, encodedValue, timeStr)
 | 
				
			||||||
 | 
						cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
 | 
				
			||||||
 | 
						return cookieVal
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func cookieSignature(args ...string) string {
 | 
				
			||||||
 | 
						h := hmac.New(sha1.New, []byte(args[0]))
 | 
				
			||||||
 | 
						for _, arg := range args[1:] {
 | 
				
			||||||
 | 
							h.Write([]byte(arg))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var b []byte
 | 
				
			||||||
 | 
						b = h.Sum(b)
 | 
				
			||||||
 | 
						return base64.URLEncoding.EncodeToString(b)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkHmac(input, expected string) bool {
 | 
				
			||||||
 | 
						inputMAC, err1 := base64.URLEncoding.DecodeString(input)
 | 
				
			||||||
 | 
						if err1 == nil {
 | 
				
			||||||
 | 
							expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
 | 
				
			||||||
 | 
							if err2 == nil {
 | 
				
			||||||
 | 
								return hmac.Equal(inputMAC, expectedMAC)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Cipher provides methods to encrypt and decrypt cookie values
 | 
				
			||||||
 | 
					type Cipher struct {
 | 
				
			||||||
 | 
						cipher.Block
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewCipher returns a new aes Cipher for encrypting cookie values
 | 
				
			||||||
 | 
					func NewCipher(secret string) (*Cipher, error) {
 | 
				
			||||||
 | 
						c, err := aes.NewCipher([]byte(secret))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &Cipher{Block: c}, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Encrypt a value for use in a cookie
 | 
				
			||||||
 | 
					func (c *Cipher) Encrypt(value string) (string, error) {
 | 
				
			||||||
 | 
						ciphertext := make([]byte, aes.BlockSize+len(value))
 | 
				
			||||||
 | 
						iv := ciphertext[:aes.BlockSize]
 | 
				
			||||||
 | 
						if _, err := io.ReadFull(rand.Reader, iv); err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to create initialization vector %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						stream := cipher.NewCFBEncrypter(c.Block, iv)
 | 
				
			||||||
 | 
						stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value))
 | 
				
			||||||
 | 
						return base64.StdEncoding.EncodeToString(ciphertext), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Decrypt a value from a cookie to it's original string
 | 
				
			||||||
 | 
					func (c *Cipher) Decrypt(s string) (string, error) {
 | 
				
			||||||
 | 
						encrypted, err := base64.StdEncoding.DecodeString(s)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to decrypt cookie value %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(encrypted) < aes.BlockSize {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("encrypted cookie value should be "+
 | 
				
			||||||
 | 
								"at least %d bytes, but is only %d bytes",
 | 
				
			||||||
 | 
								aes.BlockSize, len(encrypted))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						iv := encrypted[:aes.BlockSize]
 | 
				
			||||||
 | 
						encrypted = encrypted[aes.BlockSize:]
 | 
				
			||||||
 | 
						stream := cipher.NewCFBDecrypter(c.Block, iv)
 | 
				
			||||||
 | 
						stream.XORKeyStream(encrypted, encrypted)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return string(encrypted), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,23 @@
 | 
				
			||||||
 | 
					package cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestEncodeAndDecodeAccessToken(t *testing.T) {
 | 
				
			||||||
 | 
						const secret = "0123456789abcdefghijklmnopqrstuv"
 | 
				
			||||||
 | 
						const token = "my access token"
 | 
				
			||||||
 | 
						c, err := NewCipher(secret)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						encoded, err := c.Encrypt(token)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						decoded, err := c.Decrypt(encoded)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.NotEqual(t, token, encoded)
 | 
				
			||||||
 | 
						assert.Equal(t, token, decoded)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										140
									
								
								cookies.go
								
								
								
								
							
							
						
						
									
										140
									
								
								cookies.go
								
								
								
								
							| 
						 | 
					@ -1,140 +0,0 @@
 | 
				
			||||||
package main
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"crypto/aes"
 | 
					 | 
				
			||||||
	"crypto/cipher"
 | 
					 | 
				
			||||||
	"crypto/hmac"
 | 
					 | 
				
			||||||
	"crypto/rand"
 | 
					 | 
				
			||||||
	"crypto/sha1"
 | 
					 | 
				
			||||||
	"encoding/base64"
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"strconv"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func validateCookie(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) {
 | 
					 | 
				
			||||||
	// value, timestamp, sig
 | 
					 | 
				
			||||||
	parts := strings.Split(cookie.Value, "|")
 | 
					 | 
				
			||||||
	if len(parts) != 3 {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	sig := cookieSignature(seed, cookie.Name, parts[0], parts[1])
 | 
					 | 
				
			||||||
	if checkHmac(parts[2], sig) {
 | 
					 | 
				
			||||||
		ts, err := strconv.Atoi(parts[1])
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		// The expiration timestamp set when the cookie was created
 | 
					 | 
				
			||||||
		// isn't sent back by the browser. Hence, we check whether the
 | 
					 | 
				
			||||||
		// creation timestamp stored in the cookie falls within the
 | 
					 | 
				
			||||||
		// window defined by (Now()-expiration, Now()].
 | 
					 | 
				
			||||||
		t = time.Unix(int64(ts), 0)
 | 
					 | 
				
			||||||
		if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
 | 
					 | 
				
			||||||
			// it's a valid cookie. now get the contents
 | 
					 | 
				
			||||||
			rawValue, err := base64.URLEncoding.DecodeString(parts[0])
 | 
					 | 
				
			||||||
			if err == nil {
 | 
					 | 
				
			||||||
				value = string(rawValue)
 | 
					 | 
				
			||||||
				ok = true
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func signedCookieValue(seed string, key string, value string, now time.Time) string {
 | 
					 | 
				
			||||||
	encodedValue := base64.URLEncoding.EncodeToString([]byte(value))
 | 
					 | 
				
			||||||
	timeStr := fmt.Sprintf("%d", now.Unix())
 | 
					 | 
				
			||||||
	sig := cookieSignature(seed, key, encodedValue, timeStr)
 | 
					 | 
				
			||||||
	cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
 | 
					 | 
				
			||||||
	return cookieVal
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func cookieSignature(args ...string) string {
 | 
					 | 
				
			||||||
	h := hmac.New(sha1.New, []byte(args[0]))
 | 
					 | 
				
			||||||
	for _, arg := range args[1:] {
 | 
					 | 
				
			||||||
		h.Write([]byte(arg))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	var b []byte
 | 
					 | 
				
			||||||
	b = h.Sum(b)
 | 
					 | 
				
			||||||
	return base64.URLEncoding.EncodeToString(b)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func checkHmac(input, expected string) bool {
 | 
					 | 
				
			||||||
	inputMAC, err1 := base64.URLEncoding.DecodeString(input)
 | 
					 | 
				
			||||||
	if err1 == nil {
 | 
					 | 
				
			||||||
		expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
 | 
					 | 
				
			||||||
		if err2 == nil {
 | 
					 | 
				
			||||||
			return hmac.Equal(inputMAC, expectedMAC)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func encodeAccessToken(aes_cipher cipher.Block, access_token string) (string, error) {
 | 
					 | 
				
			||||||
	ciphertext := make([]byte, aes.BlockSize+len(access_token))
 | 
					 | 
				
			||||||
	iv := ciphertext[:aes.BlockSize]
 | 
					 | 
				
			||||||
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("failed to create access code initialization vector")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	stream := cipher.NewCFBEncrypter(aes_cipher, iv)
 | 
					 | 
				
			||||||
	stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(access_token))
 | 
					 | 
				
			||||||
	return base64.StdEncoding.EncodeToString(ciphertext), nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (string, error) {
 | 
					 | 
				
			||||||
	encrypted_access_token, err := base64.StdEncoding.DecodeString(
 | 
					 | 
				
			||||||
		encoded_access_token)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("failed to decode access token")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if len(encrypted_access_token) < aes.BlockSize {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("encrypted access token should be "+
 | 
					 | 
				
			||||||
			"at least %d bytes, but is only %d bytes",
 | 
					 | 
				
			||||||
			aes.BlockSize, len(encrypted_access_token))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	iv := encrypted_access_token[:aes.BlockSize]
 | 
					 | 
				
			||||||
	encrypted_access_token = encrypted_access_token[aes.BlockSize:]
 | 
					 | 
				
			||||||
	stream := cipher.NewCFBDecrypter(aes_cipher, iv)
 | 
					 | 
				
			||||||
	stream.XORKeyStream(encrypted_access_token, encrypted_access_token)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return string(encrypted_access_token), nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func buildCookieValue(email string, aes_cipher cipher.Block,
 | 
					 | 
				
			||||||
	access_token string) (string, error) {
 | 
					 | 
				
			||||||
	if aes_cipher == nil {
 | 
					 | 
				
			||||||
		return email, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	encoded_token, err := encodeAccessToken(aes_cipher, access_token)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return email, fmt.Errorf(
 | 
					 | 
				
			||||||
			"error encoding access token for %s: %s", email, err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return email + "|" + encoded_token, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func parseCookieValue(value string, aes_cipher cipher.Block) (email, user,
 | 
					 | 
				
			||||||
	access_token string, err error) {
 | 
					 | 
				
			||||||
	components := strings.Split(value, "|")
 | 
					 | 
				
			||||||
	email = components[0]
 | 
					 | 
				
			||||||
	user = strings.Split(email, "@")[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if aes_cipher != nil && len(components) == 2 {
 | 
					 | 
				
			||||||
		access_token, err = decodeAccessToken(aes_cipher, components[1])
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			err = fmt.Errorf(
 | 
					 | 
				
			||||||
				"error decoding access token for %s: %s",
 | 
					 | 
				
			||||||
				email, err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return email, user, access_token, err
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					@ -1,75 +0,0 @@
 | 
				
			||||||
package main
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"crypto/aes"
 | 
					 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestEncodeAndDecodeAccessToken(t *testing.T) {
 | 
					 | 
				
			||||||
	const key = "0123456789abcdefghijklmnopqrstuv"
 | 
					 | 
				
			||||||
	const access_token = "my access token"
 | 
					 | 
				
			||||||
	c, err := aes.NewCipher([]byte(key))
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	encoded_token, err := encodeAccessToken(c, access_token)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	decoded_token, err := decodeAccessToken(c, encoded_token)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	assert.NotEqual(t, access_token, encoded_token)
 | 
					 | 
				
			||||||
	assert.Equal(t, access_token, decoded_token)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestBuildCookieValueWithoutAccessToken(t *testing.T) {
 | 
					 | 
				
			||||||
	value, err := buildCookieValue("michael.bland@gsa.gov", nil, "")
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", value)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
 | 
					 | 
				
			||||||
	value, err := buildCookieValue("michael.bland@gsa.gov", nil,
 | 
					 | 
				
			||||||
		"access token")
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", value)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestParseCookieValueWithoutAccessToken(t *testing.T) {
 | 
					 | 
				
			||||||
	email, user, access_token, err := parseCookieValue(
 | 
					 | 
				
			||||||
		"michael.bland@gsa.gov", nil)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland", user)
 | 
					 | 
				
			||||||
	assert.Equal(t, "", access_token)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
 | 
					 | 
				
			||||||
	email, user, access_token, err := parseCookieValue(
 | 
					 | 
				
			||||||
		"michael.bland@gsa.gov|access_token", nil)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland", user)
 | 
					 | 
				
			||||||
	assert.Equal(t, "", access_token)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) {
 | 
					 | 
				
			||||||
	aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef"))
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
	value, err := buildCookieValue("michael.bland@gsa.gov", aes_cipher,
 | 
					 | 
				
			||||||
		"access_token")
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	prefix := "michael.bland@gsa.gov|"
 | 
					 | 
				
			||||||
	if !strings.HasPrefix(value, prefix) {
 | 
					 | 
				
			||||||
		t.Fatal("cookie value does not start with \"%s\": %s",
 | 
					 | 
				
			||||||
			prefix, value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	email, user, access_token, err := parseCookieValue(value, aes_cipher)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland", user)
 | 
					 | 
				
			||||||
	assert.Equal(t, "access_token", access_token)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										283
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										283
									
								
								oauthproxy.go
								
								
								
								
							| 
						 | 
					@ -1,8 +1,6 @@
 | 
				
			||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"crypto/aes"
 | 
					 | 
				
			||||||
	"crypto/cipher"
 | 
					 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
| 
						 | 
					@ -16,6 +14,7 @@ import (
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bitly/oauth2_proxy/cookie"
 | 
				
			||||||
	"github.com/bitly/oauth2_proxy/providers"
 | 
						"github.com/bitly/oauth2_proxy/providers"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -44,7 +43,7 @@ type OauthProxy struct {
 | 
				
			||||||
	serveMux            http.Handler
 | 
						serveMux            http.Handler
 | 
				
			||||||
	PassBasicAuth       bool
 | 
						PassBasicAuth       bool
 | 
				
			||||||
	PassAccessToken     bool
 | 
						PassAccessToken     bool
 | 
				
			||||||
	AesCipher           cipher.Block
 | 
						CookieCipher        *cookie.Cipher
 | 
				
			||||||
	skipAuthRegex       []string
 | 
						skipAuthRegex       []string
 | 
				
			||||||
	compiledRegex       []*regexp.Regexp
 | 
						compiledRegex       []*regexp.Regexp
 | 
				
			||||||
	templates           *template.Template
 | 
						templates           *template.Template
 | 
				
			||||||
| 
						 | 
					@ -116,10 +115,10 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh)
 | 
						log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var aes_cipher cipher.Block
 | 
						var cipher *cookie.Cipher
 | 
				
			||||||
	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
 | 
						if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
 | 
				
			||||||
		var err error
 | 
							var err error
 | 
				
			||||||
		aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret))
 | 
							cipher, err = cookie.NewCipher(opts.CookieSecret)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatal("error creating AES cipher with "+
 | 
								log.Fatal("error creating AES cipher with "+
 | 
				
			||||||
				"cookie-secret ", opts.CookieSecret, ": ", err)
 | 
									"cookie-secret ", opts.CookieSecret, ": ", err)
 | 
				
			||||||
| 
						 | 
					@ -150,7 +149,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
 | 
				
			||||||
		compiledRegex:   opts.CompiledRegex,
 | 
							compiledRegex:   opts.CompiledRegex,
 | 
				
			||||||
		PassBasicAuth:   opts.PassBasicAuth,
 | 
							PassBasicAuth:   opts.PassBasicAuth,
 | 
				
			||||||
		PassAccessToken: opts.PassAccessToken,
 | 
							PassAccessToken: opts.PassAccessToken,
 | 
				
			||||||
		AesCipher:       aes_cipher,
 | 
							CookieCipher:    cipher,
 | 
				
			||||||
		templates:       loadTemplates(opts.CustomTemplatesDir),
 | 
							templates:       loadTemplates(opts.CustomTemplatesDir),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -177,22 +176,20 @@ func (p *OauthProxy) displayCustomLoginForm() bool {
 | 
				
			||||||
	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
 | 
						return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OauthProxy) redeemCode(host, code string) (string, string, error) {
 | 
					func (p *OauthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
 | 
				
			||||||
	if code == "" {
 | 
						if code == "" {
 | 
				
			||||||
		return "", "", errors.New("missing code")
 | 
							return nil, errors.New("missing code")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	redirectUri := p.GetRedirectURI(host)
 | 
						redirectUri := p.GetRedirectURI(host)
 | 
				
			||||||
	body, access_token, err := p.provider.Redeem(redirectUri, code)
 | 
						s, err = p.provider.Redeem(redirectUri, code)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", "", err
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.provider.GetEmailAddress(body, access_token)
 | 
						if s.Email == "" {
 | 
				
			||||||
	if err != nil {
 | 
							s.Email, err = p.provider.GetEmailAddress(s)
 | 
				
			||||||
		return "", "", err
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
	return access_token, email, nil
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
 | 
					func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
 | 
				
			||||||
| 
						 | 
					@ -208,9 +205,8 @@ func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if value != "" {
 | 
						if value != "" {
 | 
				
			||||||
		value = signedCookieValue(p.CookieSeed, p.CookieName, value, now)
 | 
							value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	return &http.Cookie{
 | 
						return &http.Cookie{
 | 
				
			||||||
		Name:     p.CookieName,
 | 
							Name:     p.CookieName,
 | 
				
			||||||
		Value:    value,
 | 
							Value:    value,
 | 
				
			||||||
| 
						 | 
					@ -230,35 +226,34 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st
 | 
				
			||||||
	http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
 | 
						http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) {
 | 
					func (p *OauthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
 | 
				
			||||||
	var value string
 | 
						var age time.Duration
 | 
				
			||||||
	var timestamp time.Time
 | 
						c, err := req.Cookie(p.CookieName)
 | 
				
			||||||
	cookie, err := req.Cookie(p.CookieName)
 | 
					 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		value, timestamp, ok = validateCookie(cookie, p.CookieSeed, p.CookieExpire)
 | 
					 | 
				
			||||||
		if ok {
 | 
					 | 
				
			||||||
			email, user, access_token, err = parseCookieValue(value, p.AesCipher)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Printf(err.Error())
 | 
							// always http.ErrNoCookie
 | 
				
			||||||
		ok = false
 | 
							return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
 | 
				
			||||||
	} else if ok && p.CookieRefresh != time.Duration(0) {
 | 
						}
 | 
				
			||||||
		refresh := timestamp.Add(p.CookieRefresh)
 | 
						val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire)
 | 
				
			||||||
		if refresh.Before(time.Now()) {
 | 
						if !ok {
 | 
				
			||||||
			log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh)
 | 
							return nil, age, errors.New("Cookie Signature not valid")
 | 
				
			||||||
			ok = p.Validator(email)
 | 
						}
 | 
				
			||||||
			log.Printf("re-validating %s valid:%v", email, ok)
 | 
					
 | 
				
			||||||
			if ok {
 | 
						session, err := p.provider.SessionFromCookie(val, p.CookieCipher)
 | 
				
			||||||
				ok = p.provider.ValidateToken(access_token)
 | 
						if err != nil {
 | 
				
			||||||
				log.Printf("re-validating access token. valid:%v", ok)
 | 
							return nil, age, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						age = time.Now().Truncate(time.Second).Sub(timestamp)
 | 
				
			||||||
 | 
						return session, age, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OauthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
 | 
				
			||||||
 | 
						value, err := p.provider.CookieForSession(s, p.CookieCipher)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
			if ok {
 | 
					 | 
				
			||||||
	p.SetCookie(rw, req, value)
 | 
						p.SetCookie(rw, req, value)
 | 
				
			||||||
			}
 | 
						return nil
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) {
 | 
					func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) {
 | 
				
			||||||
| 
						 | 
					@ -344,54 +339,61 @@ func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) {
 | 
				
			||||||
	return redirect, err
 | 
						return redirect, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 | 
					func (p *OauthProxy) IsWhitelistedPath(path string) (ok bool) {
 | 
				
			||||||
	// check if this is a redirect back at the end of oauth
 | 
					 | 
				
			||||||
	remoteAddr := req.RemoteAddr
 | 
					 | 
				
			||||||
	if req.Header.Get("X-Real-IP") != "" {
 | 
					 | 
				
			||||||
		remoteAddr += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var ok bool
 | 
					 | 
				
			||||||
	var user string
 | 
					 | 
				
			||||||
	var email string
 | 
					 | 
				
			||||||
	var access_token string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if req.URL.Path == p.RobotsPath {
 | 
					 | 
				
			||||||
		p.RobotsTxt(rw)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if req.URL.Path == p.PingPath {
 | 
					 | 
				
			||||||
		p.PingPage(rw)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, u := range p.compiledRegex {
 | 
						for _, u := range p.compiledRegex {
 | 
				
			||||||
		match := u.MatchString(req.URL.Path)
 | 
							ok = u.MatchString(path)
 | 
				
			||||||
		if match {
 | 
							if ok {
 | 
				
			||||||
			p.serveMux.ServeHTTP(rw, req)
 | 
					 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.URL.Path == p.SignInPath {
 | 
					func getRemoteAddr(req *http.Request) (s string) {
 | 
				
			||||||
 | 
						s = req.RemoteAddr
 | 
				
			||||||
 | 
						if req.Header.Get("X-Real-IP") != "" {
 | 
				
			||||||
 | 
							s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
						switch path := req.URL.Path; {
 | 
				
			||||||
 | 
						case path == p.RobotsPath:
 | 
				
			||||||
 | 
							p.RobotsTxt(rw)
 | 
				
			||||||
 | 
						case path == p.PingPath:
 | 
				
			||||||
 | 
							p.PingPage(rw)
 | 
				
			||||||
 | 
						case p.IsWhitelistedPath(path):
 | 
				
			||||||
 | 
							p.serveMux.ServeHTTP(rw, req)
 | 
				
			||||||
 | 
						case path == p.SignInPath:
 | 
				
			||||||
 | 
							p.SignIn(rw, req)
 | 
				
			||||||
 | 
						case path == p.OauthStartPath:
 | 
				
			||||||
 | 
							p.OauthStart(rw, req)
 | 
				
			||||||
 | 
						case path == p.OauthCallbackPath:
 | 
				
			||||||
 | 
							p.OauthCallback(rw, req)
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							p.Proxy(rw, req)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OauthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
	redirect, err := p.GetRedirect(req)
 | 
						redirect, err := p.GetRedirect(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
							p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		user, ok = p.ManualSignIn(rw, req)
 | 
						user, ok := p.ManualSignIn(rw, req)
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
			p.SetCookie(rw, req, user)
 | 
							session := &providers.SessionState{User: user}
 | 
				
			||||||
 | 
							p.SaveSession(rw, req, session)
 | 
				
			||||||
		http.Redirect(rw, req, redirect, 302)
 | 
							http.Redirect(rw, req, redirect, 302)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		p.SignInPage(rw, req, 200)
 | 
							p.SignInPage(rw, req, 200)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
		return
 | 
					}
 | 
				
			||||||
	}
 | 
					
 | 
				
			||||||
	if req.URL.Path == p.OauthStartPath {
 | 
					func (p *OauthProxy) OauthStart(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
	redirect, err := p.GetRedirect(req)
 | 
						redirect, err := p.GetRedirect(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
							p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
				
			||||||
| 
						 | 
					@ -399,9 +401,11 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	redirectURI := p.GetRedirectURI(req.Host)
 | 
						redirectURI := p.GetRedirectURI(req.Host)
 | 
				
			||||||
	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
 | 
						http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
 | 
				
			||||||
		return
 | 
					}
 | 
				
			||||||
	}
 | 
					
 | 
				
			||||||
	if req.URL.Path == p.OauthCallbackPath {
 | 
					func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
						remoteAddr := getRemoteAddr(req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// finish the oauth cycle
 | 
						// finish the oauth cycle
 | 
				
			||||||
	err := req.ParseForm()
 | 
						err := req.ParseForm()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -414,10 +418,10 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code"))
 | 
						session, err := p.redeemCode(req.Host, req.Form.Get("code"))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Printf("%s error redeeming code %s", remoteAddr, err)
 | 
							log.Printf("%s error redeeming code %s", remoteAddr, err)
 | 
				
			||||||
			p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
							p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -427,73 +431,134 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// set cookie, or deny
 | 
						// set cookie, or deny
 | 
				
			||||||
		if p.Validator(email) {
 | 
						if p.Validator(session.Email) {
 | 
				
			||||||
			log.Printf("%s authenticating %s completed", remoteAddr, email)
 | 
							log.Printf("%s authentication complete %s", remoteAddr, session)
 | 
				
			||||||
			value, err := buildCookieValue(
 | 
							err := p.SaveSession(rw, req, session)
 | 
				
			||||||
				email, p.AesCipher, access_token)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
				log.Printf("%s", err)
 | 
								log.Printf("%s %s", remoteAddr, err)
 | 
				
			||||||
 | 
								p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
			p.SetCookie(rw, req, value)
 | 
					 | 
				
			||||||
		http.Redirect(rw, req, redirect, 302)
 | 
							http.Redirect(rw, req, redirect, 302)
 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
			log.Printf("validating: %s is unauthorized")
 | 
							log.Printf("%s Permission Denied: %q is unauthorized", remoteAddr, session.Email)
 | 
				
			||||||
		p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
 | 
							p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
						var saveSession, clearSession, revalidated bool
 | 
				
			||||||
 | 
						remoteAddr := getRemoteAddr(req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session, sessionAge, err := p.LoadCookiedSession(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Printf("%s %s", remoteAddr, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
 | 
				
			||||||
 | 
							log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh)
 | 
				
			||||||
 | 
							saveSession = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil {
 | 
				
			||||||
 | 
							log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
 | 
				
			||||||
 | 
							clearSession = true
 | 
				
			||||||
 | 
							session = nil
 | 
				
			||||||
 | 
						} else if ok {
 | 
				
			||||||
 | 
							saveSession = true
 | 
				
			||||||
 | 
							revalidated = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if session != nil && session.IsExpired() {
 | 
				
			||||||
 | 
							log.Printf("%s removing session. token expired %s", remoteAddr, session)
 | 
				
			||||||
 | 
							session = nil
 | 
				
			||||||
 | 
							saveSession = false
 | 
				
			||||||
 | 
							clearSession = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if saveSession && !revalidated && session.AccessToken != "" {
 | 
				
			||||||
 | 
							if !p.provider.ValidateSessionState(session) {
 | 
				
			||||||
 | 
								log.Printf("%s removing session. error validating %s", remoteAddr, session)
 | 
				
			||||||
 | 
								saveSession = false
 | 
				
			||||||
 | 
								session = nil
 | 
				
			||||||
 | 
								clearSession = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if saveSession && session.Email != "" && !p.Validator(session.Email) {
 | 
				
			||||||
 | 
							log.Printf("%s Permission Denied: removing session %s", remoteAddr, session)
 | 
				
			||||||
 | 
							session = nil
 | 
				
			||||||
 | 
							saveSession = false
 | 
				
			||||||
 | 
							clearSession = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if saveSession {
 | 
				
			||||||
 | 
							err := p.SaveSession(rw, req, session)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Printf("%s %s", remoteAddr, err)
 | 
				
			||||||
 | 
								p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !ok {
 | 
						if clearSession {
 | 
				
			||||||
		email, user, access_token, ok = p.ProcessCookie(rw, req)
 | 
							p.ClearCookie(rw, req)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !ok {
 | 
						if session == nil {
 | 
				
			||||||
		user, ok = p.CheckBasicAuth(req)
 | 
							session, err = p.CheckBasicAuth(req)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Printf("%s %s", remoteAddr, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !ok {
 | 
						if session == nil {
 | 
				
			||||||
		p.SignInPage(rw, req, 403)
 | 
							p.SignInPage(rw, req, 403)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// At this point, the user is authenticated. proxy normally
 | 
						// At this point, the user is authenticated. proxy normally
 | 
				
			||||||
	if p.PassBasicAuth {
 | 
						if p.PassBasicAuth {
 | 
				
			||||||
		req.SetBasicAuth(user, "")
 | 
							req.SetBasicAuth(session.User, "")
 | 
				
			||||||
		req.Header["X-Forwarded-User"] = []string{user}
 | 
							req.Header["X-Forwarded-User"] = []string{session.User}
 | 
				
			||||||
		req.Header["X-Forwarded-Email"] = []string{email}
 | 
							if session.Email != "" {
 | 
				
			||||||
 | 
								req.Header["X-Forwarded-Email"] = []string{session.Email}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	if p.PassAccessToken {
 | 
					 | 
				
			||||||
		req.Header["X-Forwarded-Access-Token"] = []string{access_token}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if email == "" {
 | 
						if p.PassAccessToken && session.AccessToken != "" {
 | 
				
			||||||
		rw.Header().Set("GAP-Auth", user)
 | 
							req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if session.Email == "" {
 | 
				
			||||||
 | 
							rw.Header().Set("GAP-Auth", session.User)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		rw.Header().Set("GAP-Auth", email)
 | 
							rw.Header().Set("GAP-Auth", session.Email)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p.serveMux.ServeHTTP(rw, req)
 | 
						p.serveMux.ServeHTTP(rw, req)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OauthProxy) CheckBasicAuth(req *http.Request) (string, bool) {
 | 
					func (p *OauthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
 | 
				
			||||||
	if p.HtpasswdFile == nil {
 | 
						if p.HtpasswdFile == nil {
 | 
				
			||||||
		return "", false
 | 
							return nil, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	s := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
 | 
						auth := req.Header.Get("Authorization")
 | 
				
			||||||
 | 
						if auth == "" {
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s := strings.SplitN(auth, " ", 2)
 | 
				
			||||||
	if len(s) != 2 || s[0] != "Basic" {
 | 
						if len(s) != 2 || s[0] != "Basic" {
 | 
				
			||||||
		return "", false
 | 
							return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization"))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	b, err := base64.StdEncoding.DecodeString(s[1])
 | 
						b, err := base64.StdEncoding.DecodeString(s[1])
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", false
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	pair := strings.SplitN(string(b), ":", 2)
 | 
						pair := strings.SplitN(string(b), ":", 2)
 | 
				
			||||||
	if len(pair) != 2 {
 | 
						if len(pair) != 2 {
 | 
				
			||||||
		return "", false
 | 
							return nil, fmt.Errorf("invalid format %s", b)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if p.HtpasswdFile.Validate(pair[0], pair[1]) {
 | 
						if p.HtpasswdFile.Validate(pair[0], pair[1]) {
 | 
				
			||||||
		log.Printf("authenticated %q via basic auth", pair[0])
 | 
							log.Printf("authenticated %q via basic auth", pair[0])
 | 
				
			||||||
		return pair[0], true
 | 
							return &providers.SessionState{User: pair[0]}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return "", false
 | 
						return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0])
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -94,11 +94,11 @@ type TestProvider struct {
 | 
				
			||||||
	ValidToken   bool
 | 
						ValidToken   bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
					func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
 | 
				
			||||||
	return tp.EmailAddress, nil
 | 
						return tp.EmailAddress, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (tp *TestProvider) ValidateToken(access_token string) bool {
 | 
					func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
 | 
				
			||||||
	return tp.ValidToken
 | 
						return tp.ValidToken
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -378,97 +378,73 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *ProcessCookieTest) MakeCookie(value, access_token string, ref time.Time) *http.Cookie {
 | 
					func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
 | 
				
			||||||
	cookie_value, _ := buildCookieValue(value, p.proxy.AesCipher, access_token)
 | 
						return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref)
 | 
				
			||||||
	return p.proxy.MakeCookie(p.req, cookie_value, p.opts.CookieExpire, ref)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *ProcessCookieTest) AddCookie(value, access_token string) {
 | 
					func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
 | 
				
			||||||
	p.req.AddCookie(p.MakeCookie(value, access_token, time.Now()))
 | 
						value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref))
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, ok bool) {
 | 
					func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
 | 
				
			||||||
	return p.proxy.ProcessCookie(p.rw, p.req)
 | 
						return p.proxy.LoadCookiedSession(p.req)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProcessCookie(t *testing.T) {
 | 
					func TestLoadCookiedSession(t *testing.T) {
 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
						pc_test := NewProcessCookieTestWithDefaults()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token")
 | 
						startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
 | 
				
			||||||
	email, user, access_token, ok := pc_test.ProcessCookie()
 | 
						pc_test.SaveSession(startSession, time.Now())
 | 
				
			||||||
	assert.Equal(t, true, ok)
 | 
					
 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
						session, _, err := pc_test.LoadCookiedSession()
 | 
				
			||||||
	assert.Equal(t, "michael.bland", user)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "my_access_token", access_token)
 | 
						assert.Equal(t, startSession.Email, session.Email)
 | 
				
			||||||
 | 
						assert.Equal(t, "michael.bland", session.User)
 | 
				
			||||||
 | 
						assert.Equal(t, startSession.AccessToken, session.AccessToken)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProcessCookieNoCookieError(t *testing.T) {
 | 
					func TestProcessCookieNoCookieError(t *testing.T) {
 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
						pc_test := NewProcessCookieTestWithDefaults()
 | 
				
			||||||
	_, _, _, ok := pc_test.ProcessCookie()
 | 
					 | 
				
			||||||
	assert.Equal(t, false, ok)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) {
 | 
						session, _, err := pc_test.LoadCookiedSession()
 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
						assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
 | 
				
			||||||
	value, _ := buildCookieValue("michael.bland@gsa.gov",
 | 
						if session != nil {
 | 
				
			||||||
		pc_test.proxy.AesCipher, "my_access_token")
 | 
							t.Errorf("expected nil session. got %#v", session)
 | 
				
			||||||
	pc_test.req.AddCookie(pc_test.proxy.MakeCookie(
 | 
						}
 | 
				
			||||||
		pc_test.req, value+"some bogus bytes",
 | 
					 | 
				
			||||||
		pc_test.opts.CookieExpire, time.Now()))
 | 
					 | 
				
			||||||
	_, _, _, ok := pc_test.ProcessCookie()
 | 
					 | 
				
			||||||
	assert.Equal(t, false, ok)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProcessCookieRefreshNotSet(t *testing.T) {
 | 
					func TestProcessCookieRefreshNotSet(t *testing.T) {
 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
						pc_test := NewProcessCookieTestWithDefaults()
 | 
				
			||||||
	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
 | 
						pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
 | 
				
			||||||
	reference := time.Now().Add(time.Duration(-2) * time.Hour)
 | 
						reference := time.Now().Add(time.Duration(-2) * time.Hour)
 | 
				
			||||||
	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "", reference)
 | 
					 | 
				
			||||||
	pc_test.req.AddCookie(cookie)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, _, _, ok := pc_test.ProcessCookie()
 | 
						startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
 | 
				
			||||||
	assert.Equal(t, true, ok)
 | 
						pc_test.SaveSession(startSession, reference)
 | 
				
			||||||
	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProcessCookieRefresh(t *testing.T) {
 | 
						session, age, err := pc_test.LoadCookiedSession()
 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
 | 
						if age < time.Duration(-2)*time.Hour {
 | 
				
			||||||
	reference := time.Now().Add(time.Duration(-2) * time.Hour)
 | 
							t.Errorf("cookie too young %v", age)
 | 
				
			||||||
	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
 | 
						}
 | 
				
			||||||
	pc_test.req.AddCookie(cookie)
 | 
						assert.Equal(t, startSession.Email, session.Email)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	pc_test.proxy.CookieRefresh = time.Hour
 | 
					 | 
				
			||||||
	_, _, _, ok := pc_test.ProcessCookie()
 | 
					 | 
				
			||||||
	assert.Equal(t, true, ok)
 | 
					 | 
				
			||||||
	assert.NotEqual(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
 | 
					 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
					 | 
				
			||||||
	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
 | 
					 | 
				
			||||||
	reference := time.Now().Add(time.Duration(-30) * time.Minute)
 | 
					 | 
				
			||||||
	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
 | 
					 | 
				
			||||||
	pc_test.req.AddCookie(cookie)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	pc_test.proxy.CookieRefresh = time.Hour
 | 
					 | 
				
			||||||
	_, _, _, ok := pc_test.ProcessCookie()
 | 
					 | 
				
			||||||
	assert.Equal(t, true, ok)
 | 
					 | 
				
			||||||
	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
 | 
					func TestProcessCookieFailIfCookieExpired(t *testing.T) {
 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
						pc_test := NewProcessCookieTestWithDefaults()
 | 
				
			||||||
	pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
 | 
						pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
 | 
				
			||||||
	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
 | 
						reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
 | 
				
			||||||
	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
 | 
						startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
 | 
				
			||||||
	pc_test.req.AddCookie(cookie)
 | 
						pc_test.SaveSession(startSession, reference)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if _, _, _, ok := pc_test.ProcessCookie(); ok {
 | 
						session, _, err := pc_test.LoadCookiedSession()
 | 
				
			||||||
		t.Error("ProcessCookie() should have failed")
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	}
 | 
						if session != nil {
 | 
				
			||||||
	if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil {
 | 
							t.Errorf("expected nil session %#v", session)
 | 
				
			||||||
		t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -476,44 +452,13 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
						pc_test := NewProcessCookieTestWithDefaults()
 | 
				
			||||||
	pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
 | 
						pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
 | 
				
			||||||
	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
 | 
						reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
 | 
				
			||||||
	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
 | 
						startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
 | 
				
			||||||
	pc_test.req.AddCookie(cookie)
 | 
						pc_test.SaveSession(startSession, reference)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pc_test.proxy.CookieRefresh = time.Hour
 | 
						pc_test.proxy.CookieRefresh = time.Hour
 | 
				
			||||||
	if _, _, _, ok := pc_test.ProcessCookie(); ok {
 | 
						session, _, err := pc_test.LoadCookiedSession()
 | 
				
			||||||
		t.Error("ProcessCookie() should have failed")
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	}
 | 
						if session != nil {
 | 
				
			||||||
	if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil {
 | 
							t.Errorf("expected nil session %#v", session)
 | 
				
			||||||
		t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
 | 
					 | 
				
			||||||
	pc_test := NewProcessCookieTest(ProcessCookieTestOpts{
 | 
					 | 
				
			||||||
		provider_validate_cookie_response: false,
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
 | 
					 | 
				
			||||||
	reference := time.Now().Add(time.Duration(-24) * time.Hour)
 | 
					 | 
				
			||||||
	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
 | 
					 | 
				
			||||||
	pc_test.req.AddCookie(cookie)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	pc_test.proxy.CookieRefresh = time.Hour
 | 
					 | 
				
			||||||
	_, _, _, ok := pc_test.ProcessCookie()
 | 
					 | 
				
			||||||
	assert.Equal(t, false, ok)
 | 
					 | 
				
			||||||
	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) {
 | 
					 | 
				
			||||||
	pc_test := NewProcessCookieTestWithDefaults()
 | 
					 | 
				
			||||||
	pc_test.validate_user = false
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
 | 
					 | 
				
			||||||
	reference := time.Now().Add(time.Duration(-2) * time.Hour)
 | 
					 | 
				
			||||||
	cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
 | 
					 | 
				
			||||||
	pc_test.req.AddCookie(cookie)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	pc_test.proxy.CookieRefresh = time.Hour
 | 
					 | 
				
			||||||
	_, _, _, ok := pc_test.ProcessCookie()
 | 
					 | 
				
			||||||
	assert.Equal(t, false, ok)
 | 
					 | 
				
			||||||
	assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,8 +2,10 @@ package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -138,7 +140,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
 | 
				
			||||||
	return false, nil
 | 
						return false, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
					func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var emails []struct {
 | 
						var emails []struct {
 | 
				
			||||||
		Email   string `json:"email"`
 | 
							Email   string `json:"email"`
 | 
				
			||||||
| 
						 | 
					@ -148,31 +150,34 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
 | 
				
			||||||
	// if we require an Org or Team, check that first
 | 
						// if we require an Org or Team, check that first
 | 
				
			||||||
	if p.Org != "" {
 | 
						if p.Org != "" {
 | 
				
			||||||
		if p.Team != "" {
 | 
							if p.Team != "" {
 | 
				
			||||||
			if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok {
 | 
								if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
 | 
				
			||||||
				return "", err
 | 
									return "", err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if ok, err := p.hasOrg(access_token); err != nil || !ok {
 | 
								if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
 | 
				
			||||||
				return "", err
 | 
									return "", err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	params := url.Values{
 | 
						params := url.Values{
 | 
				
			||||||
		"access_token": {access_token},
 | 
							"access_token": {s.AccessToken},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	endpoint := "https://api.github.com/user/emails?" + params.Encode()
 | 
						endpoint := "https://api.github.com/user/emails?" + params.Encode()
 | 
				
			||||||
	resp, err := http.DefaultClient.Get(endpoint)
 | 
						resp, err := http.DefaultClient.Get(endpoint)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	body, err = ioutil.ReadAll(resp.Body)
 | 
						body, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
	resp.Body.Close()
 | 
						resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if resp.StatusCode != 200 {
 | 
						if resp.StatusCode != 200 {
 | 
				
			||||||
		return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
 | 
							return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							log.Printf("got %d from %q %s", resp.StatusCode, endpoint, body)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := json.Unmarshal(body, &emails); err != nil {
 | 
						if err := json.Unmarshal(body, &emails); err != nil {
 | 
				
			||||||
| 
						 | 
					@ -185,9 +190,5 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return "", nil
 | 
						return "", errors.New("no email address found")
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (p *GitHubProvider) ValidateToken(access_token string) bool {
 | 
					 | 
				
			||||||
	return validateToken(p, access_token, nil)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -7,9 +7,11 @@ import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type GoogleProvider struct {
 | 
					type GoogleProvider struct {
 | 
				
			||||||
| 
						 | 
					@ -43,18 +45,11 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
 | 
				
			||||||
	return &GoogleProvider{ProviderData: p}
 | 
						return &GoogleProvider{ProviderData: p}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
					func emailFromIdToken(idToken string) (string, error) {
 | 
				
			||||||
	var response struct {
 | 
					 | 
				
			||||||
		IdToken string `json:"id_token"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := json.Unmarshal(body, &response); err != nil {
 | 
					 | 
				
			||||||
		return "", err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// id_token is a base64 encode ID token payload
 | 
						// id_token is a base64 encode ID token payload
 | 
				
			||||||
	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | 
						// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | 
				
			||||||
	jwt := strings.Split(response.IdToken, ".")
 | 
						jwt := strings.Split(idToken, ".")
 | 
				
			||||||
	b, err := jwtDecodeSegment(jwt[1])
 | 
						b, err := jwtDecodeSegment(jwt[1])
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
| 
						 | 
					@ -62,6 +57,7 @@ func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (stri
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var email struct {
 | 
						var email struct {
 | 
				
			||||||
		Email         string `json:"email"`
 | 
							Email         string `json:"email"`
 | 
				
			||||||
 | 
							EmailVerified bool   `json:"email_verified"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(b, &email)
 | 
						err = json.Unmarshal(b, &email)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -70,6 +66,9 @@ func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (stri
 | 
				
			||||||
	if email.Email == "" {
 | 
						if email.Email == "" {
 | 
				
			||||||
		return "", errors.New("missing email")
 | 
							return "", errors.New("missing email")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if !email.EmailVerified {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("email %s not listed as verified", email.Email)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return email.Email, nil
 | 
						return email.Email, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -81,11 +80,7 @@ func jwtDecodeSegment(seg string) ([]byte, error) {
 | 
				
			||||||
	return base64.URLEncoding.DecodeString(seg)
 | 
						return base64.URLEncoding.DecodeString(seg)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *GoogleProvider) ValidateToken(access_token string) bool {
 | 
					func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err error) {
 | 
				
			||||||
	return validateToken(p, access_token, nil)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token string, err error) {
 | 
					 | 
				
			||||||
	if code == "" {
 | 
						if code == "" {
 | 
				
			||||||
		err = errors.New("missing code")
 | 
							err = errors.New("missing code")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
| 
						 | 
					@ -108,6 +103,7 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						var body []byte
 | 
				
			||||||
	body, err = ioutil.ReadAll(resp.Body)
 | 
						body, err = ioutil.ReadAll(resp.Body)
 | 
				
			||||||
	resp.Body.Close()
 | 
						resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -122,17 +118,44 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st
 | 
				
			||||||
	var jsonResponse struct {
 | 
						var jsonResponse struct {
 | 
				
			||||||
		AccessToken  string `json:"access_token"`
 | 
							AccessToken  string `json:"access_token"`
 | 
				
			||||||
		RefreshToken string `json:"refresh_token"`
 | 
							RefreshToken string `json:"refresh_token"`
 | 
				
			||||||
 | 
							ExpiresIn    int64  `json:"expires_in"`
 | 
				
			||||||
 | 
							IdToken      string `json:"id_token"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(body, &jsonResponse)
 | 
						err = json.Unmarshal(body, &jsonResponse)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						var email string
 | 
				
			||||||
	token, err = p.redeemRefreshToken(jsonResponse.RefreshToken)
 | 
						email, err = emailFromIdToken(jsonResponse.IdToken)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s = &SessionState{
 | 
				
			||||||
 | 
							AccessToken:  jsonResponse.AccessToken,
 | 
				
			||||||
 | 
							ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
 | 
				
			||||||
 | 
							RefreshToken: jsonResponse.RefreshToken,
 | 
				
			||||||
 | 
							Email:        email,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, err error) {
 | 
					func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
 | 
				
			||||||
 | 
						if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
 | 
				
			||||||
 | 
							return false, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						newToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						origExpiration := s.ExpiresOn
 | 
				
			||||||
 | 
						s.AccessToken = newToken
 | 
				
			||||||
 | 
						s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
 | 
				
			||||||
 | 
						log.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
 | 
				
			||||||
 | 
						return true, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) {
 | 
				
			||||||
	// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
 | 
						// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
 | 
				
			||||||
	params := url.Values{}
 | 
						params := url.Values{}
 | 
				
			||||||
	params.Add("client_id", p.ClientID)
 | 
						params.Add("client_id", p.ClientID)
 | 
				
			||||||
| 
						 | 
					@ -162,12 +185,15 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string,
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var jsonResponse struct {
 | 
						var data struct {
 | 
				
			||||||
		AccessToken string `json:"access_token"`
 | 
							AccessToken string `json:"access_token"`
 | 
				
			||||||
 | 
							ExpiresIn   int64  `json:"expires_in"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(body, &jsonResponse)
 | 
						err = json.Unmarshal(body, &data)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return jsonResponse.AccessToken, nil
 | 
						token = data.AccessToken
 | 
				
			||||||
 | 
						expires = time.Duration(data.ExpiresIn) * time.Second
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,11 +3,22 @@ package providers
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newRedeemServer(body []byte) (*url.URL, *httptest.Server) {
 | 
				
			||||||
 | 
						s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							rw.Write(body)
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						u, _ := url.Parse(s.URL)
 | 
				
			||||||
 | 
						return u, s
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newGoogleProvider() *GoogleProvider {
 | 
					func newGoogleProvider() *GoogleProvider {
 | 
				
			||||||
	return NewGoogleProvider(
 | 
						return NewGoogleProvider(
 | 
				
			||||||
		&ProviderData{
 | 
							&ProviderData{
 | 
				
			||||||
| 
						 | 
					@ -66,63 +77,88 @@ func TestGoogleProviderOverrides(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, "profile", p.Data().Scope)
 | 
						assert.Equal(t, "profile", p.Data().Scope)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGoogleProviderGetEmailAddress(t *testing.T) {
 | 
					type redeemResponse struct {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						AccessToken  string `json:"access_token"`
 | 
				
			||||||
	body, err := json.Marshal(
 | 
						RefreshToken string `json:"refresh_token"`
 | 
				
			||||||
		struct {
 | 
						ExpiresIn    int64  `json:"expires_in"`
 | 
				
			||||||
	IdToken      string `json:"id_token"`
 | 
						IdToken      string `json:"id_token"`
 | 
				
			||||||
		}{
 | 
					 | 
				
			||||||
			IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
	email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
					 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestGoogleProviderGetEmailAddress(t *testing.T) {
 | 
				
			||||||
 | 
						p := newGoogleProvider()
 | 
				
			||||||
 | 
						body, err := json.Marshal(redeemResponse{
 | 
				
			||||||
 | 
							AccessToken:  "a1234",
 | 
				
			||||||
 | 
							ExpiresIn:    10,
 | 
				
			||||||
 | 
							RefreshToken: "refresh12345",
 | 
				
			||||||
 | 
							IdToken:      "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						var server *httptest.Server
 | 
				
			||||||
 | 
						p.RedeemUrl, server = newRedeemServer(body)
 | 
				
			||||||
 | 
						defer server.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session, err := p.Redeem("http://redirect/", "code1234")
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.NotEqual(t, session, nil)
 | 
				
			||||||
 | 
						assert.Equal(t, "michael.bland@gsa.gov", session.Email)
 | 
				
			||||||
 | 
						assert.Equal(t, "a1234", session.AccessToken)
 | 
				
			||||||
 | 
						assert.Equal(t, "refresh12345", session.RefreshToken)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
 | 
					func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						p := newGoogleProvider()
 | 
				
			||||||
	body, err := json.Marshal(
 | 
						body, err := json.Marshal(redeemResponse{
 | 
				
			||||||
		struct {
 | 
							AccessToken: "a1234",
 | 
				
			||||||
			IdToken string `json:"id_token"`
 | 
					 | 
				
			||||||
		}{
 | 
					 | 
				
			||||||
		IdToken:     "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
 | 
							IdToken:     "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
 | 
				
			||||||
		},
 | 
						})
 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
						var server *httptest.Server
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						p.RedeemUrl, server = newRedeemServer(body)
 | 
				
			||||||
 | 
						defer server.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session, err := p.Redeem("http://redirect/", "code1234")
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
 | 
						if session != nil {
 | 
				
			||||||
 | 
							t.Errorf("expect nill session %#v", session)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
 | 
					func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						p := newGoogleProvider()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	body, err := json.Marshal(
 | 
						body, err := json.Marshal(redeemResponse{
 | 
				
			||||||
		struct {
 | 
							AccessToken: "a1234",
 | 
				
			||||||
			IdToken string `json:"id_token"`
 | 
					 | 
				
			||||||
		}{
 | 
					 | 
				
			||||||
		IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
 | 
							IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
 | 
				
			||||||
		},
 | 
						})
 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
						var server *httptest.Server
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						p.RedeemUrl, server = newRedeemServer(body)
 | 
				
			||||||
 | 
						defer server.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session, err := p.Redeem("http://redirect/", "code1234")
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
 | 
						if session != nil {
 | 
				
			||||||
 | 
							t.Errorf("expect nill session %#v", session)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
 | 
					func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
 | 
				
			||||||
	p := newGoogleProvider()
 | 
						p := newGoogleProvider()
 | 
				
			||||||
	body, err := json.Marshal(
 | 
						body, err := json.Marshal(redeemResponse{
 | 
				
			||||||
		struct {
 | 
							AccessToken: "a1234",
 | 
				
			||||||
			IdToken string `json:"id_token"`
 | 
					 | 
				
			||||||
		}{
 | 
					 | 
				
			||||||
		IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
 | 
							IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
 | 
				
			||||||
		},
 | 
						})
 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	email, err := p.GetEmailAddress(body, "ignored access_token")
 | 
						var server *httptest.Server
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						p.RedeemUrl, server = newRedeemServer(body)
 | 
				
			||||||
 | 
						defer server.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session, err := p.Redeem("http://redirect/", "code1234")
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
 | 
						if session != nil {
 | 
				
			||||||
 | 
							t.Errorf("expect nill session %#v", session)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,6 +9,7 @@ import (
 | 
				
			||||||
	"github.com/bitly/oauth2_proxy/api"
 | 
						"github.com/bitly/oauth2_proxy/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// validateToken returns true if token is valid
 | 
				
			||||||
func validateToken(p Provider, access_token string, header http.Header) bool {
 | 
					func validateToken(p Provider, access_token string, header http.Header) bool {
 | 
				
			||||||
	if access_token == "" || p.Data().ValidateUrl == nil {
 | 
						if access_token == "" || p.Data().ValidateUrl == nil {
 | 
				
			||||||
		return false
 | 
							return false
 | 
				
			||||||
| 
						 | 
					@ -20,12 +21,15 @@ func validateToken(p Provider, access_token string, header http.Header) bool {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	resp, err := api.RequestUnparsedResponse(endpoint, header)
 | 
						resp, err := api.RequestUnparsedResponse(endpoint, header)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Printf("GET %s", endpoint)
 | 
				
			||||||
		log.Printf("token validation request failed: %s", err)
 | 
							log.Printf("token validation request failed: %s", err)
 | 
				
			||||||
		return false
 | 
							return false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	body, _ := ioutil.ReadAll(resp.Body)
 | 
						body, _ := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
	resp.Body.Close()
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						log.Printf("%d GET %s %s", resp.StatusCode, endpoint, body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if resp.StatusCode == 200 {
 | 
						if resp.StatusCode == 200 {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,36 +1,38 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
						"errors"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ValidateTokenTestProvider struct {
 | 
					type ValidateSessionStateTestProvider struct {
 | 
				
			||||||
	*ProviderData
 | 
						*ProviderData
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
					func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
 | 
				
			||||||
	return "", nil
 | 
						return "", errors.New("not implemented")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Note that we're testing the internal validateToken() used to implement
 | 
					// Note that we're testing the internal validateToken() used to implement
 | 
				
			||||||
// several Provider's ValidateToken() implementations
 | 
					// several Provider's ValidateSessionState() implementations
 | 
				
			||||||
func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool {
 | 
					func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ValidateTokenTest struct {
 | 
					type ValidateSessionStateTest struct {
 | 
				
			||||||
	backend       *httptest.Server
 | 
						backend       *httptest.Server
 | 
				
			||||||
	response_code int
 | 
						response_code int
 | 
				
			||||||
	provider      *ValidateTokenTestProvider
 | 
						provider      *ValidateSessionStateTestProvider
 | 
				
			||||||
	header        http.Header
 | 
						header        http.Header
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewValidateTokenTest() *ValidateTokenTest {
 | 
					func NewValidateSessionStateTest() *ValidateSessionStateTest {
 | 
				
			||||||
	var vt_test ValidateTokenTest
 | 
						var vt_test ValidateSessionStateTest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	vt_test.backend = httptest.NewServer(
 | 
						vt_test.backend = httptest.NewServer(
 | 
				
			||||||
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
							http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
| 
						 | 
					@ -59,7 +61,7 @@ func NewValidateTokenTest() *ValidateTokenTest {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		}))
 | 
							}))
 | 
				
			||||||
	backend_url, _ := url.Parse(vt_test.backend.URL)
 | 
						backend_url, _ := url.Parse(vt_test.backend.URL)
 | 
				
			||||||
	vt_test.provider = &ValidateTokenTestProvider{
 | 
						vt_test.provider = &ValidateSessionStateTestProvider{
 | 
				
			||||||
		ProviderData: &ProviderData{
 | 
							ProviderData: &ProviderData{
 | 
				
			||||||
			ValidateUrl: &url.URL{
 | 
								ValidateUrl: &url.URL{
 | 
				
			||||||
				Scheme: "http",
 | 
									Scheme: "http",
 | 
				
			||||||
| 
						 | 
					@ -72,18 +74,18 @@ func NewValidateTokenTest() *ValidateTokenTest {
 | 
				
			||||||
	return &vt_test
 | 
						return &vt_test
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (vt_test *ValidateTokenTest) Close() {
 | 
					func (vt_test *ValidateSessionStateTest) Close() {
 | 
				
			||||||
	vt_test.backend.Close()
 | 
						vt_test.backend.Close()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestValidateTokenValidToken(t *testing.T) {
 | 
					func TestValidateSessionStateValidToken(t *testing.T) {
 | 
				
			||||||
	vt_test := NewValidateTokenTest()
 | 
						vt_test := NewValidateSessionStateTest()
 | 
				
			||||||
	defer vt_test.Close()
 | 
						defer vt_test.Close()
 | 
				
			||||||
	assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil))
 | 
						assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestValidateTokenValidTokenWithHeaders(t *testing.T) {
 | 
					func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
 | 
				
			||||||
	vt_test := NewValidateTokenTest()
 | 
						vt_test := NewValidateSessionStateTest()
 | 
				
			||||||
	defer vt_test.Close()
 | 
						defer vt_test.Close()
 | 
				
			||||||
	vt_test.header = make(http.Header)
 | 
						vt_test.header = make(http.Header)
 | 
				
			||||||
	vt_test.header.Set("Authorization", "Bearer foobar")
 | 
						vt_test.header.Set("Authorization", "Bearer foobar")
 | 
				
			||||||
| 
						 | 
					@ -91,28 +93,28 @@ func TestValidateTokenValidTokenWithHeaders(t *testing.T) {
 | 
				
			||||||
		validateToken(vt_test.provider, "foobar", vt_test.header))
 | 
							validateToken(vt_test.provider, "foobar", vt_test.header))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestValidateTokenEmptyToken(t *testing.T) {
 | 
					func TestValidateSessionStateEmptyToken(t *testing.T) {
 | 
				
			||||||
	vt_test := NewValidateTokenTest()
 | 
						vt_test := NewValidateSessionStateTest()
 | 
				
			||||||
	defer vt_test.Close()
 | 
						defer vt_test.Close()
 | 
				
			||||||
	assert.Equal(t, false, validateToken(vt_test.provider, "", nil))
 | 
						assert.Equal(t, false, validateToken(vt_test.provider, "", nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestValidateTokenEmptyValidateUrl(t *testing.T) {
 | 
					func TestValidateSessionStateEmptyValidateUrl(t *testing.T) {
 | 
				
			||||||
	vt_test := NewValidateTokenTest()
 | 
						vt_test := NewValidateSessionStateTest()
 | 
				
			||||||
	defer vt_test.Close()
 | 
						defer vt_test.Close()
 | 
				
			||||||
	vt_test.provider.Data().ValidateUrl = nil
 | 
						vt_test.provider.Data().ValidateUrl = nil
 | 
				
			||||||
	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
 | 
						assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestValidateTokenRequestNetworkFailure(t *testing.T) {
 | 
					func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
 | 
				
			||||||
	vt_test := NewValidateTokenTest()
 | 
						vt_test := NewValidateSessionStateTest()
 | 
				
			||||||
	// Close immediately to simulate a network failure
 | 
						// Close immediately to simulate a network failure
 | 
				
			||||||
	vt_test.Close()
 | 
						vt_test.Close()
 | 
				
			||||||
	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
 | 
						assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestValidateTokenExpiredToken(t *testing.T) {
 | 
					func TestValidateSessionStateExpiredToken(t *testing.T) {
 | 
				
			||||||
	vt_test := NewValidateTokenTest()
 | 
						vt_test := NewValidateSessionStateTest()
 | 
				
			||||||
	defer vt_test.Close()
 | 
						defer vt_test.Close()
 | 
				
			||||||
	vt_test.response_code = 401
 | 
						vt_test.response_code = 401
 | 
				
			||||||
	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
 | 
						assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,6 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
					 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
| 
						 | 
					@ -49,16 +48,15 @@ func getLinkedInHeader(access_token string) http.Header {
 | 
				
			||||||
	return header
 | 
						return header
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
					func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
 | 
				
			||||||
	if access_token == "" {
 | 
						if s.AccessToken == "" {
 | 
				
			||||||
		return "", errors.New("missing access token")
 | 
							return "", errors.New("missing access token")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	params := url.Values{}
 | 
						req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", nil)
 | 
				
			||||||
	req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", bytes.NewBufferString(params.Encode()))
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req.Header = getLinkedInHeader(access_token)
 | 
						req.Header = getLinkedInHeader(s.AccessToken)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	json, err := api.Request(req)
 | 
						json, err := api.Request(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -74,6 +72,6 @@ func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (st
 | 
				
			||||||
	return email, nil
 | 
						return email, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *LinkedInProvider) ValidateToken(access_token string) bool {
 | 
					func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
 | 
				
			||||||
	return validateToken(p, access_token, getLinkedInHeader(access_token))
 | 
						return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -97,8 +97,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testLinkedInProvider(b_url.Host)
 | 
						p := testLinkedInProvider(b_url.Host)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress([]byte{},
 | 
						session := &SessionState{AccessToken: "imaginary_access_token"}
 | 
				
			||||||
		"imaginary_access_token")
 | 
						email, err := p.GetEmailAddress(session)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "user@linkedin.com", email)
 | 
						assert.Equal(t, "user@linkedin.com", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -113,7 +113,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
 | 
				
			||||||
	// We'll trigger a request failure by using an unexpected access
 | 
						// We'll trigger a request failure by using an unexpected access
 | 
				
			||||||
	// token. Alternatively, we could allow the parsing of the payload as
 | 
						// token. Alternatively, we could allow the parsing of the payload as
 | 
				
			||||||
	// JSON to fail.
 | 
						// JSON to fail.
 | 
				
			||||||
	email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token")
 | 
						session := &SessionState{AccessToken: "unexpected_access_token"}
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(session)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -125,7 +126,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testLinkedInProvider(b_url.Host)
 | 
						p := testLinkedInProvider(b_url.Host)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
 | 
						session := &SessionState{AccessToken: "imaginary_access_token"}
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(session)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -42,9 +42,9 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider {
 | 
				
			||||||
	return &MyUsaProvider{ProviderData: p}
 | 
						return &MyUsaProvider{ProviderData: p}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) {
 | 
					func (p *MyUsaProvider) GetEmailAddress(s *SessionState) (string, error) {
 | 
				
			||||||
	req, err := http.NewRequest("GET",
 | 
						req, err := http.NewRequest("GET",
 | 
				
			||||||
		p.ProfileUrl.String()+"?access_token="+access_token, nil)
 | 
							p.ProfileUrl.String()+"?access_token="+s.AccessToken, nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Printf("failed building request %s", err)
 | 
							log.Printf("failed building request %s", err)
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
| 
						 | 
					@ -56,7 +56,3 @@ func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (strin
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return json.Get("email").String()
 | 
						return json.Get("email").String()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (p *MyUsaProvider) ValidateToken(access_token string) bool {
 | 
					 | 
				
			||||||
	return validateToken(p, access_token, nil)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,11 +1,12 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/bmizerany/assert"
 | 
					 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func updateUrl(url *url.URL, hostname string) {
 | 
					func updateUrl(url *url.URL, hostname string) {
 | 
				
			||||||
| 
						 | 
					@ -102,7 +103,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) {
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testMyUsaProvider(b_url.Host)
 | 
						p := testMyUsaProvider(b_url.Host)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
 | 
						session := &SessionState{AccessToken: "imaginary_access_token"}
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(session)
 | 
				
			||||||
	assert.Equal(t, nil, err)
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
						assert.Equal(t, "michael.bland@gsa.gov", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -119,7 +121,8 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) {
 | 
				
			||||||
	// We'll trigger a request failure by using an unexpected access
 | 
						// We'll trigger a request failure by using an unexpected access
 | 
				
			||||||
	// token. Alternatively, we could allow the parsing of the payload as
 | 
						// token. Alternatively, we could allow the parsing of the payload as
 | 
				
			||||||
	// JSON to fail.
 | 
						// JSON to fail.
 | 
				
			||||||
	email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token")
 | 
						session := &SessionState{AccessToken: "unexpected_access_token"}
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(session)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -131,7 +134,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
 | 
				
			||||||
	b_url, _ := url.Parse(b.URL)
 | 
						b_url, _ := url.Parse(b.URL)
 | 
				
			||||||
	p := testMyUsaProvider(b_url.Host)
 | 
						p := testMyUsaProvider(b_url.Host)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token")
 | 
						session := &SessionState{AccessToken: "imaginary_access_token"}
 | 
				
			||||||
 | 
						email, err := p.GetEmailAddress(session)
 | 
				
			||||||
	assert.NotEqual(t, nil, err)
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
	assert.Equal(t, "", email)
 | 
						assert.Equal(t, "", email)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,9 +9,11 @@ import (
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bitly/oauth2_proxy/cookie"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) {
 | 
					func (p *ProviderData) Redeem(redirectUrl, code string) (s *SessionState, err error) {
 | 
				
			||||||
	if code == "" {
 | 
						if code == "" {
 | 
				
			||||||
		err = errors.New("missing code")
 | 
							err = errors.New("missing code")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
| 
						 | 
					@ -23,24 +25,28 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri
 | 
				
			||||||
	params.Add("client_secret", p.ClientSecret)
 | 
						params.Add("client_secret", p.ClientSecret)
 | 
				
			||||||
	params.Add("code", code)
 | 
						params.Add("code", code)
 | 
				
			||||||
	params.Add("grant_type", "authorization_code")
 | 
						params.Add("grant_type", "authorization_code")
 | 
				
			||||||
	req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode()))
 | 
						var req *http.Request
 | 
				
			||||||
 | 
						req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode()))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, "", err
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
						req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp, err := http.DefaultClient.Do(req)
 | 
						var resp *http.Response
 | 
				
			||||||
 | 
						resp, err = http.DefaultClient.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, "", err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						var body []byte
 | 
				
			||||||
	body, err = ioutil.ReadAll(resp.Body)
 | 
						body, err = ioutil.ReadAll(resp.Body)
 | 
				
			||||||
	resp.Body.Close()
 | 
						resp.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, "", err
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if resp.StatusCode != 200 {
 | 
						if resp.StatusCode != 200 {
 | 
				
			||||||
		return body, "", fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body)
 | 
							err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// blindly try json and x-www-form-urlencoded
 | 
						// blindly try json and x-www-form-urlencoded
 | 
				
			||||||
| 
						 | 
					@ -49,11 +55,23 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(body, &jsonResponse)
 | 
						err = json.Unmarshal(body, &jsonResponse)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		return body, jsonResponse.AccessToken, nil
 | 
							s = &SessionState{
 | 
				
			||||||
 | 
								AccessToken: jsonResponse.AccessToken,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	v, err := url.ParseQuery(string(body))
 | 
						var v url.Values
 | 
				
			||||||
	return body, v.Get("access_token"), err
 | 
						v, err = url.ParseQuery(string(body))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if a := v.Get("access_token"); a != "" {
 | 
				
			||||||
 | 
							s = &SessionState{AccessToken: a}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							err = fmt.Errorf("no access token found %s", body)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetLoginURL with typical oauth parameters
 | 
					// GetLoginURL with typical oauth parameters
 | 
				
			||||||
| 
						 | 
					@ -72,3 +90,26 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
 | 
				
			||||||
	a.RawQuery = params.Encode()
 | 
						a.RawQuery = params.Encode()
 | 
				
			||||||
	return a.String()
 | 
						return a.String()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CookieForSession serializes a session state for storage in a cookie
 | 
				
			||||||
 | 
					func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) {
 | 
				
			||||||
 | 
						return s.EncodeSessionState(c)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SessionFromCookie deserializes a session from a cookie value
 | 
				
			||||||
 | 
					func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) {
 | 
				
			||||||
 | 
						return DecodeSessionState(v, c)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
 | 
				
			||||||
 | 
						return "", errors.New("not implemented")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
 | 
				
			||||||
 | 
						return validateToken(p, s.AccessToken, nil)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RefreshSessionIfNeeded
 | 
				
			||||||
 | 
					func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
 | 
				
			||||||
 | 
						return false, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,17 @@
 | 
				
			||||||
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRefresh(t *testing.T) {
 | 
				
			||||||
 | 
						p := &ProviderData{}
 | 
				
			||||||
 | 
						refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
 | 
				
			||||||
 | 
							ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						assert.Equal(t, false, refreshed)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -1,11 +1,18 @@
 | 
				
			||||||
package providers
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/bitly/oauth2_proxy/cookie"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Provider interface {
 | 
					type Provider interface {
 | 
				
			||||||
	Data() *ProviderData
 | 
						Data() *ProviderData
 | 
				
			||||||
	GetEmailAddress(body []byte, access_token string) (string, error)
 | 
						GetEmailAddress(*SessionState) (string, error)
 | 
				
			||||||
	Redeem(string, string) ([]byte, string, error)
 | 
						Redeem(string, string) (*SessionState, error)
 | 
				
			||||||
	ValidateToken(access_token string) bool
 | 
						ValidateSessionState(*SessionState) bool
 | 
				
			||||||
	GetLoginURL(redirectURI, finalRedirect string) string
 | 
						GetLoginURL(redirectURI, finalRedirect string) string
 | 
				
			||||||
 | 
						RefreshSessionIfNeeded(*SessionState) (bool, error)
 | 
				
			||||||
 | 
						SessionFromCookie(string, *cookie.Cipher) (*SessionState, error)
 | 
				
			||||||
 | 
						CookieForSession(*SessionState, *cookie.Cipher) (string, error)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func New(provider string, p *ProviderData) Provider {
 | 
					func New(provider string, p *ProviderData) Provider {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,115 @@
 | 
				
			||||||
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bitly/oauth2_proxy/cookie"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SessionState struct {
 | 
				
			||||||
 | 
						AccessToken  string
 | 
				
			||||||
 | 
						ExpiresOn    time.Time
 | 
				
			||||||
 | 
						RefreshToken string
 | 
				
			||||||
 | 
						Email        string
 | 
				
			||||||
 | 
						User         string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SessionState) IsExpired() bool {
 | 
				
			||||||
 | 
						if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SessionState) String() string {
 | 
				
			||||||
 | 
						o := fmt.Sprintf("Session{%s", s.userOrEmail())
 | 
				
			||||||
 | 
						if s.AccessToken != "" {
 | 
				
			||||||
 | 
							o += " token:true"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !s.ExpiresOn.IsZero() {
 | 
				
			||||||
 | 
							o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if s.RefreshToken != "" {
 | 
				
			||||||
 | 
							o += " refresh_token:true"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return o + "}"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
 | 
				
			||||||
 | 
						if c == nil || s.AccessToken == "" {
 | 
				
			||||||
 | 
							return s.userOrEmail(), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return s.EncryptedString(c)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SessionState) userOrEmail() string {
 | 
				
			||||||
 | 
						u := s.User
 | 
				
			||||||
 | 
						if s.Email != "" {
 | 
				
			||||||
 | 
							u = s.Email
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return u
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						if c == nil {
 | 
				
			||||||
 | 
							panic("error. missing cipher")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						a := s.AccessToken
 | 
				
			||||||
 | 
						if a != "" {
 | 
				
			||||||
 | 
							a, err = c.Encrypt(a)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						r := s.RefreshToken
 | 
				
			||||||
 | 
						if r != "" {
 | 
				
			||||||
 | 
							r, err = c.Encrypt(r)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
 | 
				
			||||||
 | 
						chunks := strings.Split(v, "|")
 | 
				
			||||||
 | 
						if len(chunks) == 1 {
 | 
				
			||||||
 | 
							if strings.Contains(chunks[0], "@") {
 | 
				
			||||||
 | 
								u := strings.Split(v, "@")[0]
 | 
				
			||||||
 | 
								return &SessionState{Email: v, User: u}, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return &SessionState{User: v}, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(chunks) != 4 {
 | 
				
			||||||
 | 
							err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s = &SessionState{}
 | 
				
			||||||
 | 
						if c != nil && chunks[1] != "" {
 | 
				
			||||||
 | 
							s.AccessToken, err = c.Decrypt(chunks[1])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if c != nil && chunks[3] != "" {
 | 
				
			||||||
 | 
							s.RefreshToken, err = c.Decrypt(chunks[3])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if u := chunks[0]; strings.Contains(u, "@") {
 | 
				
			||||||
 | 
							s.Email = u
 | 
				
			||||||
 | 
							s.User = strings.Split(u, "@")[0]
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							s.User = u
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ts, _ := strconv.Atoi(chunks[2])
 | 
				
			||||||
 | 
						s.ExpiresOn = time.Unix(int64(ts), 0)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,88 @@
 | 
				
			||||||
 | 
					package providers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bitly/oauth2_proxy/cookie"
 | 
				
			||||||
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const secret = "0123456789abcdefghijklmnopqrstuv"
 | 
				
			||||||
 | 
					const altSecret = "0000000000abcdefghijklmnopqrstuv"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSessionStateSerialization(t *testing.T) {
 | 
				
			||||||
 | 
						c, err := cookie.NewCipher(secret)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						c2, err := cookie.NewCipher(altSecret)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						s := &SessionState{
 | 
				
			||||||
 | 
							Email:        "user@domain.com",
 | 
				
			||||||
 | 
							AccessToken:  "token1234",
 | 
				
			||||||
 | 
							ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour),
 | 
				
			||||||
 | 
							RefreshToken: "refresh4321",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						encoded, err := s.EncodeSessionState(c)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.Equal(t, 3, strings.Count(encoded, "|"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ss, err := DecodeSessionState(encoded, c)
 | 
				
			||||||
 | 
						t.Logf("%#v", ss)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.Equal(t, s.Email, ss.Email)
 | 
				
			||||||
 | 
						assert.Equal(t, s.AccessToken, ss.AccessToken)
 | 
				
			||||||
 | 
						assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
 | 
				
			||||||
 | 
						assert.Equal(t, s.RefreshToken, ss.RefreshToken)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 
				
			||||||
 | 
						ss, err = DecodeSessionState(encoded, c2)
 | 
				
			||||||
 | 
						t.Logf("%#v", ss)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.Equal(t, s.Email, ss.Email)
 | 
				
			||||||
 | 
						assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
 | 
				
			||||||
 | 
						assert.NotEqual(t, s.AccessToken, ss.AccessToken)
 | 
				
			||||||
 | 
						assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSessionStateSerializationNoCipher(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s := &SessionState{
 | 
				
			||||||
 | 
							Email:        "user@domain.com",
 | 
				
			||||||
 | 
							AccessToken:  "token1234",
 | 
				
			||||||
 | 
							ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour),
 | 
				
			||||||
 | 
							RefreshToken: "refresh4321",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						encoded, err := s.EncodeSessionState(nil)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.Equal(t, s.Email, encoded)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// only email should have been serialized
 | 
				
			||||||
 | 
						ss, err := DecodeSessionState(encoded, nil)
 | 
				
			||||||
 | 
						assert.Equal(t, nil, err)
 | 
				
			||||||
 | 
						assert.Equal(t, s.Email, ss.Email)
 | 
				
			||||||
 | 
						assert.Equal(t, "", ss.AccessToken)
 | 
				
			||||||
 | 
						assert.Equal(t, "", ss.RefreshToken)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSessionStateUserOrEmail(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s := &SessionState{
 | 
				
			||||||
 | 
							Email: "user@domain.com",
 | 
				
			||||||
 | 
							User:  "just-user",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						assert.Equal(t, "user@domain.com", s.userOrEmail())
 | 
				
			||||||
 | 
						s.Email = ""
 | 
				
			||||||
 | 
						assert.Equal(t, "just-user", s.userOrEmail())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestExpired(t *testing.T) {
 | 
				
			||||||
 | 
						s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
 | 
				
			||||||
 | 
						assert.Equal(t, true, s.IsExpired())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
 | 
				
			||||||
 | 
						assert.Equal(t, false, s.IsExpired())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s = &SessionState{}
 | 
				
			||||||
 | 
						assert.Equal(t, false, s.IsExpired())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue