Merge pull request #148 from pusher/proxy-session-store
Proxy session store
This commit is contained in:
		
						commit
						10e240c8bf
					
				|  | @ -10,6 +10,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v3.2.0 | ## Changes since v3.2.0 | ||||||
| 
 | 
 | ||||||
|  | - [#148](https://github.com/pusher/outh2_proxy/pull/148) Implement SessionStore interface within proxy (@JoelSpeed) | ||||||
| - [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed) | - [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed) | ||||||
|   - Allows for multiple different session storage implementations including client and server side |   - Allows for multiple different session storage implementations including client and server side | ||||||
|   - Adds tests suite for interface to ensure consistency across implementations |   - Adds tests suite for interface to ensure consistency across implementations | ||||||
|  |  | ||||||
							
								
								
									
										191
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										191
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -16,7 +16,7 @@ import ( | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"github.com/yhat/wsutil" | 	"github.com/yhat/wsutil" | ||||||
| ) | ) | ||||||
|  | @ -29,10 +29,6 @@ const ( | ||||||
| 	httpScheme  = "http" | 	httpScheme  = "http" | ||||||
| 	httpsScheme = "https" | 	httpsScheme = "https" | ||||||
| 
 | 
 | ||||||
| 	// Cookies are limited to 4kb including the length of the cookie name,
 |  | ||||||
| 	// the cookie name can be up to 256 bytes
 |  | ||||||
| 	maxCookieLength = 3840 |  | ||||||
| 
 |  | ||||||
| 	applicationJSON = "application/json" | 	applicationJSON = "application/json" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -75,6 +71,7 @@ type OAuthProxy struct { | ||||||
| 	redirectURL         *url.URL // the url to receive requests at
 | 	redirectURL         *url.URL // the url to receive requests at
 | ||||||
| 	whitelistDomains    []string | 	whitelistDomains    []string | ||||||
| 	provider            providers.Provider | 	provider            providers.Provider | ||||||
|  | 	sessionStore        sessionsapi.SessionStore | ||||||
| 	ProxyPrefix         string | 	ProxyPrefix         string | ||||||
| 	SignInMessage       string | 	SignInMessage       string | ||||||
| 	HtpasswdFile        *HtpasswdFile | 	HtpasswdFile        *HtpasswdFile | ||||||
|  | @ -88,7 +85,6 @@ type OAuthProxy struct { | ||||||
| 	PassAccessToken     bool | 	PassAccessToken     bool | ||||||
| 	SetAuthorization    bool | 	SetAuthorization    bool | ||||||
| 	PassAuthorization   bool | 	PassAuthorization   bool | ||||||
| 	CookieCipher        *cookie.Cipher |  | ||||||
| 	skipAuthRegex       []string | 	skipAuthRegex       []string | ||||||
| 	skipAuthPreflight   bool | 	skipAuthPreflight   bool | ||||||
| 	compiledRegex       []*regexp.Regexp | 	compiledRegex       []*regexp.Regexp | ||||||
|  | @ -218,15 +214,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, refresh) | 	logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, refresh) | ||||||
| 
 | 
 | ||||||
| 	var cipher *cookie.Cipher |  | ||||||
| 	if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) { |  | ||||||
| 		var err error |  | ||||||
| 		cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Fatal("cookie-secret error: ", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return &OAuthProxy{ | 	return &OAuthProxy{ | ||||||
| 		CookieName:     opts.CookieName, | 		CookieName:     opts.CookieName, | ||||||
| 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), | 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), | ||||||
|  | @ -249,6 +236,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 
 | 
 | ||||||
| 		ProxyPrefix:        opts.ProxyPrefix, | 		ProxyPrefix:        opts.ProxyPrefix, | ||||||
| 		provider:           opts.provider, | 		provider:           opts.provider, | ||||||
|  | 		sessionStore:       opts.sessionStore, | ||||||
| 		serveMux:           serveMux, | 		serveMux:           serveMux, | ||||||
| 		redirectURL:        redirectURL, | 		redirectURL:        redirectURL, | ||||||
| 		whitelistDomains:   opts.WhitelistDomains, | 		whitelistDomains:   opts.WhitelistDomains, | ||||||
|  | @ -263,7 +251,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 		SetAuthorization:   opts.SetAuthorization, | 		SetAuthorization:   opts.SetAuthorization, | ||||||
| 		PassAuthorization:  opts.PassAuthorization, | 		PassAuthorization:  opts.PassAuthorization, | ||||||
| 		SkipProviderButton: opts.SkipProviderButton, | 		SkipProviderButton: opts.SkipProviderButton, | ||||||
| 		CookieCipher:       cipher, |  | ||||||
| 		templates:          loadTemplates(opts.CustomTemplatesDir), | 		templates:          loadTemplates(opts.CustomTemplatesDir), | ||||||
| 		Footer:             opts.Footer, | 		Footer:             opts.Footer, | ||||||
| 	} | 	} | ||||||
|  | @ -293,7 +280,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { | ||||||
| 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { | func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, errors.New("missing code") | 		return nil, errors.New("missing code") | ||||||
| 	} | 	} | ||||||
|  | @ -316,104 +303,6 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, er | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // MakeSessionCookie creates an http.Cookie containing the authenticated user's
 |  | ||||||
| // authentication details
 |  | ||||||
| func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { |  | ||||||
| 	if value != "" { |  | ||||||
| 		value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) |  | ||||||
| 	} |  | ||||||
| 	c := p.makeCookie(req, p.CookieName, value, expiration, now) |  | ||||||
| 	if len(c.Value) > 4096-len(p.CookieName) { |  | ||||||
| 		return splitCookie(c) |  | ||||||
| 	} |  | ||||||
| 	return []*http.Cookie{c} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func copyCookie(c *http.Cookie) *http.Cookie { |  | ||||||
| 	return &http.Cookie{ |  | ||||||
| 		Name:       c.Name, |  | ||||||
| 		Value:      c.Value, |  | ||||||
| 		Path:       c.Path, |  | ||||||
| 		Domain:     c.Domain, |  | ||||||
| 		Expires:    c.Expires, |  | ||||||
| 		RawExpires: c.RawExpires, |  | ||||||
| 		MaxAge:     c.MaxAge, |  | ||||||
| 		Secure:     c.Secure, |  | ||||||
| 		HttpOnly:   c.HttpOnly, |  | ||||||
| 		Raw:        c.Raw, |  | ||||||
| 		Unparsed:   c.Unparsed, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // splitCookie reads the full cookie generated to store the session and splits
 |  | ||||||
| // it into a slice of cookies which fit within the 4kb cookie limit indexing
 |  | ||||||
| // the cookies from 0
 |  | ||||||
| func splitCookie(c *http.Cookie) []*http.Cookie { |  | ||||||
| 	if len(c.Value) < maxCookieLength { |  | ||||||
| 		return []*http.Cookie{c} |  | ||||||
| 	} |  | ||||||
| 	cookies := []*http.Cookie{} |  | ||||||
| 	valueBytes := []byte(c.Value) |  | ||||||
| 	count := 0 |  | ||||||
| 	for len(valueBytes) > 0 { |  | ||||||
| 		new := copyCookie(c) |  | ||||||
| 		new.Name = fmt.Sprintf("%s_%d", c.Name, count) |  | ||||||
| 		count++ |  | ||||||
| 		if len(valueBytes) < maxCookieLength { |  | ||||||
| 			new.Value = string(valueBytes) |  | ||||||
| 			valueBytes = []byte{} |  | ||||||
| 		} else { |  | ||||||
| 			newValue := valueBytes[:maxCookieLength] |  | ||||||
| 			valueBytes = valueBytes[maxCookieLength:] |  | ||||||
| 			new.Value = string(newValue) |  | ||||||
| 		} |  | ||||||
| 		cookies = append(cookies, new) |  | ||||||
| 	} |  | ||||||
| 	return cookies |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // joinCookies takes a slice of cookies from the request and reconstructs the
 |  | ||||||
| // full session cookie
 |  | ||||||
| func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { |  | ||||||
| 	if len(cookies) == 0 { |  | ||||||
| 		return nil, fmt.Errorf("list of cookies must be > 0") |  | ||||||
| 	} |  | ||||||
| 	if len(cookies) == 1 { |  | ||||||
| 		return cookies[0], nil |  | ||||||
| 	} |  | ||||||
| 	c := copyCookie(cookies[0]) |  | ||||||
| 	for i := 1; i < len(cookies); i++ { |  | ||||||
| 		c.Value += cookies[i].Value |  | ||||||
| 	} |  | ||||||
| 	c.Name = strings.TrimRight(c.Name, "_0") |  | ||||||
| 	return c, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // loadCookie retreieves the sessions state cookie from the http request.
 |  | ||||||
| // If a single cookie is present this will be returned, otherwise it attempts
 |  | ||||||
| // to reconstruct a cookie split up by splitCookie
 |  | ||||||
| func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { |  | ||||||
| 	c, err := req.Cookie(cookieName) |  | ||||||
| 	if err == nil { |  | ||||||
| 		return c, nil |  | ||||||
| 	} |  | ||||||
| 	cookies := []*http.Cookie{} |  | ||||||
| 	err = nil |  | ||||||
| 	count := 0 |  | ||||||
| 	for err == nil { |  | ||||||
| 		var c *http.Cookie |  | ||||||
| 		c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count)) |  | ||||||
| 		if err == nil { |  | ||||||
| 			cookies = append(cookies, c) |  | ||||||
| 			count++ |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if len(cookies) == 0 { |  | ||||||
| 		return nil, fmt.Errorf("Could not find cookie %s", cookieName) |  | ||||||
| 	} |  | ||||||
| 	return joinCookies(cookies) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // MakeCSRFCookie creates a cookie for CSRF
 | // MakeCSRFCookie creates a cookie for CSRF
 | ||||||
| func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||||
| 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | ||||||
|  | @ -454,66 +343,18 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va | ||||||
| 
 | 
 | ||||||
| // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | ||||||
| // stored in the user's session
 | // stored in the user's session
 | ||||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { | ||||||
| 	var cookies []*http.Cookie | 	return p.sessionStore.Clear(rw, req) | ||||||
| 
 |  | ||||||
| 	// matches CookieName, CookieName_<number>
 |  | ||||||
| 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName)) |  | ||||||
| 
 |  | ||||||
| 	for _, c := range req.Cookies() { |  | ||||||
| 		if cookieNameRegex.MatchString(c.Name) { |  | ||||||
| 			clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) |  | ||||||
| 
 |  | ||||||
| 			http.SetCookie(rw, clearCookie) |  | ||||||
| 			cookies = append(cookies, clearCookie) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// ugly hack because default domain changed
 |  | ||||||
| 	if p.CookieDomain == "" && len(cookies) > 0 { |  | ||||||
| 		clr2 := *cookies[0] |  | ||||||
| 		clr2.Domain = req.Host |  | ||||||
| 		http.SetCookie(rw, &clr2) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // SetSessionCookie adds the user's session cookie to the response
 |  | ||||||
| func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { |  | ||||||
| 	for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) { |  | ||||||
| 		http.SetCookie(rw, c) |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // LoadCookiedSession reads the user's authentication details from the request
 | // LoadCookiedSession reads the user's authentication details from the request
 | ||||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { | func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
| 	var age time.Duration | 	return p.sessionStore.Load(req) | ||||||
| 	c, err := loadCookie(req, p.CookieName) |  | ||||||
| 	if err != nil { |  | ||||||
| 		// always http.ErrNoCookie
 |  | ||||||
| 		return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) |  | ||||||
| 	} |  | ||||||
| 	val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, age, errors.New("Cookie Signature not valid") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	session, err := p.provider.SessionFromCookie(val, p.CookieCipher) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, age, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	age = time.Now().Truncate(time.Second).Sub(timestamp) |  | ||||||
| 	return session, age, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SaveSession creates a new session cookie value and sets this on the response
 | // SaveSession creates a new session cookie value and sets this on the response
 | ||||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | ||||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | 	return p.sessionStore.Save(rw, req, s) | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	p.SetSessionCookie(rw, req, value) |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RobotsTxt disallows scraping pages from the OAuthProxy
 | // RobotsTxt disallows scraping pages from the OAuthProxy
 | ||||||
|  | @ -694,7 +535,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 
 | 
 | ||||||
| 	user, ok := p.ManualSignIn(rw, req) | 	user, ok := p.ManualSignIn(rw, req) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		session := &sessions.SessionState{User: user} | 		session := &sessionsapi.SessionState{User: user} | ||||||
| 		p.SaveSession(rw, req, session) | 		p.SaveSession(rw, req, session) | ||||||
| 		http.Redirect(rw, req, redirect, 302) | 		http.Redirect(rw, req, redirect, 302) | ||||||
| 	} else { | 	} else { | ||||||
|  | @ -833,12 +674,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 	var saveSession, clearSession, revalidated bool | 	var saveSession, clearSession, revalidated bool | ||||||
| 	remoteAddr := getRemoteAddr(req) | 	remoteAddr := getRemoteAddr(req) | ||||||
| 
 | 
 | ||||||
| 	session, sessionAge, err := p.LoadCookiedSession(req) | 	session, err := p.LoadCookiedSession(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error loading cookied session: %s", err) | 		logger.Printf("Error loading cookied session: %s", err) | ||||||
| 	} | 	} | ||||||
| 	if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | 	if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | ||||||
| 		logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh) | 		logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) | ||||||
| 		saveSession = true | 		saveSession = true | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -945,7 +786,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 
 | 
 | ||||||
| // CheckBasicAuth checks the requests Authorization header for basic auth
 | // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||||
| // credentials and authenticates these against the proxies HtpasswdFile
 | // credentials and authenticates these against the proxies HtpasswdFile
 | ||||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { | func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
| 	if p.HtpasswdFile == nil { | 	if p.HtpasswdFile == nil { | ||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  | @ -967,7 +808,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, | ||||||
| 	} | 	} | ||||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | ||||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||||
| 		return &sessions.SessionState{User: pair[0]}, nil | 		return &sessionsapi.SessionState{User: pair[0]}, nil | ||||||
| 	} | 	} | ||||||
| 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ import ( | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
|  | @ -600,10 +601,15 @@ type ProcessCookieTestOpts struct { | ||||||
| 	providerValidateCookieResponse bool | 	providerValidateCookieResponse bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { | type OptionsModifier func(*Options) | ||||||
|  | 
 | ||||||
|  | func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { | ||||||
| 	var pcTest ProcessCookieTest | 	var pcTest ProcessCookieTest | ||||||
| 
 | 
 | ||||||
| 	pcTest.opts = NewOptions() | 	pcTest.opts = NewOptions() | ||||||
|  | 	for _, modifier := range modifiers { | ||||||
|  | 		modifier(pcTest.opts) | ||||||
|  | 	} | ||||||
| 	pcTest.opts.ClientID = "bazquux" | 	pcTest.opts.ClientID = "bazquux" | ||||||
| 	pcTest.opts.ClientSecret = "xyzzyplugh" | 	pcTest.opts.ClientSecret = "xyzzyplugh" | ||||||
| 	pcTest.opts.CookieSecret = "0123456789abcdefabcd" | 	pcTest.opts.CookieSecret = "0123456789abcdefabcd" | ||||||
|  | @ -634,32 +640,34 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { | func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest { | ||||||
| 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | 	return NewProcessCookieTest(ProcessCookieTestOpts{ | ||||||
|  | 		providerValidateCookieResponse: true, | ||||||
|  | 	}, modifiers...) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { | func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error { | ||||||
| 	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) | 	err := p.proxy.SaveSession(p.rw, p.req, s) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { | 	for _, cookie := range p.rw.Result().Cookies() { | ||||||
| 		p.req.AddCookie(c) | 		p.req.AddCookie(cookie) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { | func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) { | ||||||
| 	return p.proxy.LoadCookiedSession(p.req) | 	return p.proxy.LoadCookiedSession(p.req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestLoadCookiedSession(t *testing.T) { | func TestLoadCookiedSession(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 
 | 
 | ||||||
| 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||||
| 	pcTest.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, startSession.Email, session.Email) | 	assert.Equal(t, startSession.Email, session.Email) | ||||||
| 	assert.Equal(t, "john.doe@example.com", session.User) | 	assert.Equal(t, "john.doe@example.com", session.User) | ||||||
|  | @ -669,7 +677,7 @@ func TestLoadCookiedSession(t *testing.T) { | ||||||
| func TestProcessCookieNoCookieError(t *testing.T) { | func TestProcessCookieNoCookieError(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session. got %#v", session) | 		t.Errorf("expected nil session. got %#v", session) | ||||||
|  | @ -677,29 +685,31 @@ func TestProcessCookieNoCookieError(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieRefreshNotSet(t *testing.T) { | func TestProcessCookieRefreshNotSet(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour | 		opts.CookieExpire = time.Duration(23) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||||
| 
 | 
 | ||||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	session, age, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	if age < time.Duration(-2)*time.Hour { | 	if session.Age() < time.Duration(-2)*time.Hour { | ||||||
| 		t.Errorf("cookie too young %v", age) | 		t.Errorf("cookie too young %v", session.Age()) | ||||||
| 	} | 	} | ||||||
| 	assert.Equal(t, startSession.Email, session.Email) | 	assert.Equal(t, startSession.Email, session.Email) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieFailIfCookieExpired(t *testing.T) { | func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session %#v", session) | 		t.Errorf("expected nil session %#v", session) | ||||||
|  | @ -707,22 +717,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	pcTest.proxy.CookieRefresh = time.Hour | 	pcTest.proxy.CookieRefresh = time.Hour | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session %#v", session) | 		t.Errorf("expected nil session %#v", session) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewAuthOnlyEndpointTest() *ProcessCookieTest { | func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...) | ||||||
| 	pcTest.req, _ = http.NewRequest("GET", | 	pcTest.req, _ = http.NewRequest("GET", | ||||||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||||
| 	return pcTest | 	return pcTest | ||||||
|  | @ -731,8 +742,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { | ||||||
| func TestAuthOnlyEndpointAccepted(t *testing.T) { | func TestAuthOnlyEndpointAccepted(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest() | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||||
| 	test.SaveSession(startSession, time.Now()) | 	test.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
| 	assert.Equal(t, http.StatusAccepted, test.rw.Code) | 	assert.Equal(t, http.StatusAccepted, test.rw.Code) | ||||||
|  | @ -750,12 +761,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest(func(opts *Options) { | ||||||
| 	test.proxy.CookieExpire = time.Duration(24) * time.Hour | 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	test.SaveSession(startSession, reference) | 	test.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | ||||||
|  | @ -766,8 +778,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | ||||||
| func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest() | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||||
| 	test.SaveSession(startSession, time.Now()) | 	test.SaveSession(startSession) | ||||||
| 	test.validateUser = false | 	test.validateUser = false | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
|  | @ -797,8 +809,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | ||||||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||||
| 
 | 
 | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} | ||||||
| 	pcTest.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) | 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) | ||||||
| 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) | 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) | ||||||
|  | @ -930,11 +942,11 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { | ||||||
| 
 | 
 | ||||||
| 	state := &sessions.SessionState{ | 	state := &sessions.SessionState{ | ||||||
| 		Email: "mbland@acm.org", AccessToken: "my_access_token"} | 		Email: "mbland@acm.org", AccessToken: "my_access_token"} | ||||||
| 	value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) | 	err = proxy.SaveSession(st.rw, req, state) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
| 	for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) { | 	for _, c := range st.rw.Result().Cookies() { | ||||||
| 		req.AddCookie(c) | 		req.AddCookie(c) | ||||||
| 	} | 	} | ||||||
| 	// This is used by the upstream to validate the signature.
 | 	// This is used by the upstream to validate the signature.
 | ||||||
|  | @ -1068,7 +1080,12 @@ func TestAjaxForbiddendRequest(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestClearSplitCookie(t *testing.T) { | func TestClearSplitCookie(t *testing.T) { | ||||||
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} | 	opts := NewOptions() | ||||||
|  | 	opts.CookieName = "oauth2" | ||||||
|  | 	opts.CookieDomain = "abc" | ||||||
|  | 	store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) | ||||||
|  | 	assert.Equal(t, err, nil) | ||||||
|  | 	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} | ||||||
| 	var rw = httptest.NewRecorder() | 	var rw = httptest.NewRecorder() | ||||||
| 	req := httptest.NewRequest("get", "/", nil) | 	req := httptest.NewRequest("get", "/", nil) | ||||||
| 
 | 
 | ||||||
|  | @ -1092,7 +1109,12 @@ func TestClearSplitCookie(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestClearSingleCookie(t *testing.T) { | func TestClearSingleCookie(t *testing.T) { | ||||||
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} | 	opts := NewOptions() | ||||||
|  | 	opts.CookieName = "oauth2" | ||||||
|  | 	opts.CookieDomain = "abc" | ||||||
|  | 	store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) | ||||||
|  | 	assert.Equal(t, err, nil) | ||||||
|  | 	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} | ||||||
| 	var rw = httptest.NewRecorder() | 	var rw = httptest.NewRecorder() | ||||||
| 	req := httptest.NewRequest("get", "/", nil) | 	req := httptest.NewRequest("get", "/", nil) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										24
									
								
								options.go
								
								
								
								
							
							
						
						
									
										24
									
								
								options.go
								
								
								
								
							|  | @ -17,8 +17,11 @@ import ( | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
| 	"github.com/dgrijalva/jwt-go" | 	"github.com/dgrijalva/jwt-go" | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
|  | 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"gopkg.in/natefinch/lumberjack.v2" | 	"gopkg.in/natefinch/lumberjack.v2" | ||||||
| ) | ) | ||||||
|  | @ -111,6 +114,7 @@ type Options struct { | ||||||
| 	proxyURLs     []*url.URL | 	proxyURLs     []*url.URL | ||||||
| 	CompiledRegex []*regexp.Regexp | 	CompiledRegex []*regexp.Regexp | ||||||
| 	provider      providers.Provider | 	provider      providers.Provider | ||||||
|  | 	sessionStore  sessionsapi.SessionStore | ||||||
| 	signatureData *SignatureData | 	signatureData *SignatureData | ||||||
| 	oidcVerifier  *oidc.IDTokenVerifier | 	oidcVerifier  *oidc.IDTokenVerifier | ||||||
| } | } | ||||||
|  | @ -136,6 +140,9 @@ func NewOptions() *Options { | ||||||
| 			CookieExpire:   time.Duration(168) * time.Hour, | 			CookieExpire:   time.Duration(168) * time.Hour, | ||||||
| 			CookieRefresh:  time.Duration(0), | 			CookieRefresh:  time.Duration(0), | ||||||
| 		}, | 		}, | ||||||
|  | 		SessionOptions: options.SessionOptions{ | ||||||
|  | 			Type: "cookie", | ||||||
|  | 		}, | ||||||
| 		SetXAuthRequest:       false, | 		SetXAuthRequest:       false, | ||||||
| 		SkipAuthPreflight:     false, | 		SkipAuthPreflight:     false, | ||||||
| 		PassBasicAuth:         true, | 		PassBasicAuth:         true, | ||||||
|  | @ -261,7 +268,8 @@ func (o *Options) Validate() error { | ||||||
| 	} | 	} | ||||||
| 	msgs = parseProviderInfo(o, msgs) | 	msgs = parseProviderInfo(o, msgs) | ||||||
| 
 | 
 | ||||||
| 	if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { | 	var cipher *cookie.Cipher | ||||||
|  | 	if o.PassAccessToken || o.SetAuthorization || o.PassAuthorization || (o.CookieRefresh != time.Duration(0)) { | ||||||
| 		validCookieSecretSize := false | 		validCookieSecretSize := false | ||||||
| 		for _, i := range []int{16, 24, 32} { | 		for _, i := range []int{16, 24, 32} { | ||||||
| 			if len(secretBytes(o.CookieSecret)) == i { | 			if len(secretBytes(o.CookieSecret)) == i { | ||||||
|  | @ -283,8 +291,22 @@ func (o *Options) Validate() error { | ||||||
| 					"pass_access_token == true or "+ | 					"pass_access_token == true or "+ | ||||||
| 					"cookie_refresh != 0, but is %d bytes.%s", | 					"cookie_refresh != 0, but is %d bytes.%s", | ||||||
| 				len(secretBytes(o.CookieSecret)), suffix)) | 				len(secretBytes(o.CookieSecret)), suffix)) | ||||||
|  | 		} else { | ||||||
|  | 			var err error | ||||||
|  | 			cipher, err = cookie.NewCipher(secretBytes(o.CookieSecret)) | ||||||
|  | 			if err != nil { | ||||||
|  | 				msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err)) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	o.SessionOptions.Cipher = cipher | ||||||
|  | 	sessionStore, err := sessions.NewSessionStore(&o.SessionOptions, &o.CookieOptions) | ||||||
|  | 	if err != nil { | ||||||
|  | 		msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err)) | ||||||
|  | 	} else { | ||||||
|  | 		o.sessionStore = sessionStore | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	if o.CookieRefresh >= o.CookieExpire { | 	if o.CookieRefresh >= o.CookieExpire { | ||||||
| 		msgs = append(msgs, fmt.Sprintf( | 		msgs = append(msgs, fmt.Sprintf( | ||||||
|  |  | ||||||
|  | @ -1,8 +1,13 @@ | ||||||
| package options | package options | ||||||
| 
 | 
 | ||||||
|  | import ( | ||||||
|  | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| // SessionOptions contains configuration options for the SessionStore providers.
 | // SessionOptions contains configuration options for the SessionStore providers.
 | ||||||
| type SessionOptions struct { | type SessionOptions struct { | ||||||
| 	Type   string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | 	Type   string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | ||||||
|  | 	Cipher *cookie.Cipher | ||||||
| 	CookieStoreOptions | 	CookieStoreOptions | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -14,6 +14,7 @@ import ( | ||||||
| type SessionState struct { | type SessionState struct { | ||||||
| 	AccessToken  string    `json:",omitempty"` | 	AccessToken  string    `json:",omitempty"` | ||||||
| 	IDToken      string    `json:",omitempty"` | 	IDToken      string    `json:",omitempty"` | ||||||
|  | 	CreatedAt    time.Time `json:"-"` | ||||||
| 	ExpiresOn    time.Time `json:"-"` | 	ExpiresOn    time.Time `json:"-"` | ||||||
| 	RefreshToken string    `json:",omitempty"` | 	RefreshToken string    `json:",omitempty"` | ||||||
| 	Email        string    `json:",omitempty"` | 	Email        string    `json:",omitempty"` | ||||||
|  | @ -23,6 +24,7 @@ type SessionState struct { | ||||||
| // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
 | // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
 | ||||||
| type SessionStateJSON struct { | type SessionStateJSON struct { | ||||||
| 	*SessionState | 	*SessionState | ||||||
|  | 	CreatedAt *time.Time `json:",omitempty"` | ||||||
| 	ExpiresOn *time.Time `json:",omitempty"` | 	ExpiresOn *time.Time `json:",omitempty"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Age returns the age of a session
 | ||||||
|  | func (s *SessionState) Age() time.Duration { | ||||||
|  | 	if !s.CreatedAt.IsZero() { | ||||||
|  | 		return time.Now().Truncate(time.Second).Sub(s.CreatedAt) | ||||||
|  | 	} | ||||||
|  | 	return 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // String constructs a summary of the session state
 | // String constructs a summary of the session state
 | ||||||
| func (s *SessionState) String() string { | func (s *SessionState) String() string { | ||||||
| 	o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) | 	o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) | ||||||
|  | @ -43,6 +53,9 @@ func (s *SessionState) String() string { | ||||||
| 	if s.IDToken != "" { | 	if s.IDToken != "" { | ||||||
| 		o += " id_token:true" | 		o += " id_token:true" | ||||||
| 	} | 	} | ||||||
|  | 	if !s.CreatedAt.IsZero() { | ||||||
|  | 		o += fmt.Sprintf(" created:%s", s.CreatedAt) | ||||||
|  | 	} | ||||||
| 	if !s.ExpiresOn.IsZero() { | 	if !s.ExpiresOn.IsZero() { | ||||||
| 		o += fmt.Sprintf(" expires:%s", s.ExpiresOn) | 		o += fmt.Sprintf(" expires:%s", s.ExpiresOn) | ||||||
| 	} | 	} | ||||||
|  | @ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { | ||||||
| 	} | 	} | ||||||
| 	// Embed SessionState and ExpiresOn pointer into SessionStateJSON
 | 	// Embed SessionState and ExpiresOn pointer into SessionStateJSON
 | ||||||
| 	ssj := &SessionStateJSON{SessionState: &ss} | 	ssj := &SessionStateJSON{SessionState: &ss} | ||||||
|  | 	if !ss.CreatedAt.IsZero() { | ||||||
|  | 		ssj.CreatedAt = &ss.CreatedAt | ||||||
|  | 	} | ||||||
| 	if !ss.ExpiresOn.IsZero() { | 	if !ss.ExpiresOn.IsZero() { | ||||||
| 		ssj.ExpiresOn = &ss.ExpiresOn | 		ssj.ExpiresOn = &ss.ExpiresOn | ||||||
| 	} | 	} | ||||||
|  | @ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { | ||||||
| 	var ss *SessionState | 	var ss *SessionState | ||||||
| 	err := json.Unmarshal([]byte(v), &ssj) | 	err := json.Unmarshal([]byte(v), &ssj) | ||||||
| 	if err == nil && ssj.SessionState != nil { | 	if err == nil && ssj.SessionState != nil { | ||||||
| 		// Extract SessionState and ExpiresOn value from SessionStateJSON
 | 		// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
 | ||||||
| 		ss = ssj.SessionState | 		ss = ssj.SessionState | ||||||
|  | 		if ssj.CreatedAt != nil { | ||||||
|  | 			ss.CreatedAt = *ssj.CreatedAt | ||||||
|  | 		} | ||||||
| 		if ssj.ExpiresOn != nil { | 		if ssj.ExpiresOn != nil { | ||||||
| 			ss.ExpiresOn = *ssj.ExpiresOn | 			ss.ExpiresOn = *ssj.ExpiresOn | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
| 		IDToken:      "rawtoken1234", | 		IDToken:      "rawtoken1234", | ||||||
|  | 		CreatedAt:    time.Now(), | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||||
| 		RefreshToken: "refresh4321", | 		RefreshToken: "refresh4321", | ||||||
| 	} | 	} | ||||||
|  | @ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
| 	assert.Equal(t, s.AccessToken, ss.AccessToken) | 	assert.Equal(t, s.AccessToken, ss.AccessToken) | ||||||
| 	assert.Equal(t, s.IDToken, ss.IDToken) | 	assert.Equal(t, s.IDToken, ss.IDToken) | ||||||
|  | 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||||
| 
 | 
 | ||||||
|  | @ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.NotEqual(t, "user@domain.com", ss.User) | 	assert.NotEqual(t, "user@domain.com", ss.User) | ||||||
| 	assert.NotEqual(t, s.Email, ss.Email) | 	assert.NotEqual(t, s.Email, ss.Email) | ||||||
|  | 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||||
| 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | ||||||
| 	assert.NotEqual(t, s.IDToken, ss.IDToken) | 	assert.NotEqual(t, s.IDToken, ss.IDToken) | ||||||
|  | @ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 		User:         "just-user", | 		User:         "just-user", | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
|  | 		CreatedAt:    time.Now(), | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||||
| 		RefreshToken: "refresh4321", | 		RefreshToken: "refresh4321", | ||||||
| 	} | 	} | ||||||
|  | @ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, s.User, ss.User) | 	assert.Equal(t, s.User, ss.User) | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
| 	assert.Equal(t, s.AccessToken, ss.AccessToken) | 	assert.Equal(t, s.AccessToken, ss.AccessToken) | ||||||
|  | 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||||
| 
 | 
 | ||||||
|  | @ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.NotEqual(t, s.User, ss.User) | 	assert.NotEqual(t, s.User, ss.User) | ||||||
| 	assert.NotEqual(t, s.Email, ss.Email) | 	assert.NotEqual(t, s.Email, ss.Email) | ||||||
|  | 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||||
| 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | ||||||
| 	assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) | 	assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) | ||||||
|  | @ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||||
| 	s := &sessions.SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
|  | 		CreatedAt:    time.Now(), | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||||
| 		RefreshToken: "refresh4321", | 		RefreshToken: "refresh4321", | ||||||
| 	} | 	} | ||||||
|  | @ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||||
| 		User:         "just-user", | 		User:         "just-user", | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
|  | 		CreatedAt:    time.Now(), | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||||
| 		RefreshToken: "refresh4321", | 		RefreshToken: "refresh4321", | ||||||
| 	} | 	} | ||||||
|  | @ -147,6 +155,7 @@ type testCase struct { | ||||||
| // Currently only tests without cipher here because we have no way to mock
 | // Currently only tests without cipher here because we have no way to mock
 | ||||||
| // the random generator used in EncodeSessionState.
 | // the random generator used in EncodeSessionState.
 | ||||||
| func TestEncodeSessionState(t *testing.T) { | func TestEncodeSessionState(t *testing.T) { | ||||||
|  | 	c := time.Now() | ||||||
| 	e := time.Now().Add(time.Duration(1) * time.Hour) | 	e := time.Now().Add(time.Duration(1) * time.Hour) | ||||||
| 
 | 
 | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
|  | @ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) { | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
| 				IDToken:      "rawtoken1234", | 				IDToken:      "rawtoken1234", | ||||||
|  | 				CreatedAt:    c, | ||||||
| 				ExpiresOn:    e, | 				ExpiresOn:    e, | ||||||
| 				RefreshToken: "refresh4321", | 				RefreshToken: "refresh4321", | ||||||
| 			}, | 			}, | ||||||
|  | @ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
 | // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
 | ||||||
| func TestDecodeSessionState(t *testing.T) { | func TestDecodeSessionState(t *testing.T) { | ||||||
|  | 	created := time.Now() | ||||||
|  | 	createdJSON, _ := created.MarshalJSON() | ||||||
|  | 	createdString := string(createdJSON) | ||||||
| 	e := time.Now().Add(time.Duration(1) * time.Hour) | 	e := time.Now().Add(time.Duration(1) * time.Hour) | ||||||
| 	eJSON, _ := e.MarshalJSON() | 	eJSON, _ := e.MarshalJSON() | ||||||
| 	eString := string(eJSON) | 	eString := string(eJSON) | ||||||
|  | @ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), | 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: sessions.SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
|  | @ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
| 				IDToken:      "rawtoken1234", | 				IDToken:      "rawtoken1234", | ||||||
|  | 				CreatedAt:    created, | ||||||
| 				ExpiresOn:    e, | 				ExpiresOn:    e, | ||||||
| 				RefreshToken: "refresh4321", | 				RefreshToken: "refresh4321", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), | 			Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), | ||||||
| 			Cipher:  c, | 			Cipher:  c, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
|  | @ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestSessionStateAge(t *testing.T) { | ||||||
|  | 	ss := &sessions.SessionState{} | ||||||
|  | 
 | ||||||
|  | 	// Created at unset so should be 0
 | ||||||
|  | 	assert.Equal(t, time.Duration(0), ss.Age()) | ||||||
|  | 
 | ||||||
|  | 	// Set CreatedAt to 1 hour ago
 | ||||||
|  | 	ss.CreatedAt = time.Now().Add(-1 * time.Hour) | ||||||
|  | 	assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // MakeCookie constructs a cookie from the given parameters,
 | // MakeCookie constructs a cookie from the given parameters,
 | ||||||
|  | @ -32,3 +33,9 @@ func MakeCookie(req *http.Request, name string, value string, path string, domai | ||||||
| 		Expires:  now.Add(expiration), | 		Expires:  now.Add(expiration), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // MakeCookieFromOptions constructs a cookie based on the givemn *options.CookieOptions,
 | ||||||
|  | // value and creation time
 | ||||||
|  | func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.CookieOptions, expiration time.Duration, now time.Time) *http.Cookie { | ||||||
|  | 	return MakeCookie(req, name, value, opts.CookiePath, opts.CookieDomain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -27,36 +27,33 @@ var _ sessions.SessionStore = &SessionStore{} | ||||||
| // SessionStore is an implementation of the sessions.SessionStore
 | // SessionStore is an implementation of the sessions.SessionStore
 | ||||||
| // interface that stores sessions in client side cookies
 | // interface that stores sessions in client side cookies
 | ||||||
| type SessionStore struct { | type SessionStore struct { | ||||||
|  | 	CookieOptions *options.CookieOptions | ||||||
| 	CookieCipher  *cookie.Cipher | 	CookieCipher  *cookie.Cipher | ||||||
| 	CookieDomain   string |  | ||||||
| 	CookieExpire   time.Duration |  | ||||||
| 	CookieHTTPOnly bool |  | ||||||
| 	CookieName     string |  | ||||||
| 	CookiePath     string |  | ||||||
| 	CookieSecret   string |  | ||||||
| 	CookieSecure   bool |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Save takes a sessions.SessionState and stores the information from it
 | // Save takes a sessions.SessionState and stores the information from it
 | ||||||
| // within Cookies set on the HTTP response writer
 | // within Cookies set on the HTTP response writer
 | ||||||
| func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | ||||||
|  | 	if ss.CreatedAt.IsZero() { | ||||||
|  | 		ss.CreatedAt = time.Now() | ||||||
|  | 	} | ||||||
| 	value, err := utils.CookieForSession(ss, s.CookieCipher) | 	value, err := utils.CookieForSession(ss, s.CookieCipher) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	s.setSessionCookie(rw, req, value) | 	s.setSessionCookie(rw, req, value, ss.CreatedAt) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Load reads sessions.SessionState information from Cookies within the
 | // Load reads sessions.SessionState information from Cookies within the
 | ||||||
| // HTTP request object
 | // HTTP request object
 | ||||||
| func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { | func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { | ||||||
| 	c, err := loadCookie(req, s.CookieName) | 	c, err := loadCookie(req, s.CookieOptions.CookieName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// always http.ErrNoCookie
 | 		// always http.ErrNoCookie
 | ||||||
| 		return nil, fmt.Errorf("Cookie %q not present", s.CookieName) | 		return nil, fmt.Errorf("Cookie %q not present", s.CookieOptions.CookieName) | ||||||
| 	} | 	} | ||||||
| 	val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire) | 	val, _, ok := cookie.Validate(c, s.CookieOptions.CookieSecret, s.CookieOptions.CookieExpire) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, errors.New("Cookie Signature not valid") | 		return nil, errors.New("Cookie Signature not valid") | ||||||
| 	} | 	} | ||||||
|  | @ -74,11 +71,11 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||||
| 	var cookies []*http.Cookie | 	var cookies []*http.Cookie | ||||||
| 
 | 
 | ||||||
| 	// matches CookieName, CookieName_<number>
 | 	// matches CookieName, CookieName_<number>
 | ||||||
| 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName)) | 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieOptions.CookieName)) | ||||||
| 
 | 
 | ||||||
| 	for _, c := range req.Cookies() { | 	for _, c := range req.Cookies() { | ||||||
| 		if cookieNameRegex.MatchString(c.Name) { | 		if cookieNameRegex.MatchString(c.Name) { | ||||||
| 			clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) | 			clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) | ||||||
| 
 | 
 | ||||||
| 			http.SetCookie(rw, clearCookie) | 			http.SetCookie(rw, clearCookie) | ||||||
| 			cookies = append(cookies, clearCookie) | 			cookies = append(cookies, clearCookie) | ||||||
|  | @ -89,60 +86,42 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // setSessionCookie adds the user's session cookie to the response
 | // setSessionCookie adds the user's session cookie to the response
 | ||||||
| func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) { | ||||||
| 	for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { | 	for _, c := range s.makeSessionCookie(req, val, created) { | ||||||
| 		http.SetCookie(rw, c) | 		http.SetCookie(rw, c) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // makeSessionCookie creates an http.Cookie containing the authenticated user's
 | // makeSessionCookie creates an http.Cookie containing the authenticated user's
 | ||||||
| // authentication details
 | // authentication details
 | ||||||
| func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { | func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie { | ||||||
| 	if value != "" { | 	if value != "" { | ||||||
| 		value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) | 		value = cookie.SignedValue(s.CookieOptions.CookieSecret, s.CookieOptions.CookieName, value, now) | ||||||
| 	} | 	} | ||||||
| 	c := s.makeCookie(req, s.CookieName, value, expiration) | 	c := s.makeCookie(req, s.CookieOptions.CookieName, value, s.CookieOptions.CookieExpire, now) | ||||||
| 	if len(c.Value) > 4096-len(s.CookieName) { | 	if len(c.Value) > 4096-len(s.CookieOptions.CookieName) { | ||||||
| 		return splitCookie(c) | 		return splitCookie(c) | ||||||
| 	} | 	} | ||||||
| 	return []*http.Cookie{c} | 	return []*http.Cookie{c} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { | func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||||
| 	return cookies.MakeCookie( | 	return cookies.MakeCookieFromOptions( | ||||||
| 		req, | 		req, | ||||||
| 		name, | 		name, | ||||||
| 		value, | 		value, | ||||||
| 		s.CookiePath, | 		s.CookieOptions, | ||||||
| 		s.CookieDomain, |  | ||||||
| 		s.CookieHTTPOnly, |  | ||||||
| 		s.CookieSecure, |  | ||||||
| 		expiration, | 		expiration, | ||||||
| 		time.Now(), | 		now, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewCookieSessionStore initialises a new instance of the SessionStore from
 | // NewCookieSessionStore initialises a new instance of the SessionStore from
 | ||||||
| // the configuration given
 | // the configuration given
 | ||||||
| func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||||
| 	var cipher *cookie.Cipher |  | ||||||
| 	if len(cookieOpts.CookieSecret) > 0 { |  | ||||||
| 		var err error |  | ||||||
| 		cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, fmt.Errorf("unable to create cipher: %v", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return &SessionStore{ | 	return &SessionStore{ | ||||||
| 		CookieCipher:   cipher, | 		CookieCipher:  opts.Cipher, | ||||||
| 		CookieDomain:   cookieOpts.CookieDomain, | 		CookieOptions: cookieOpts, | ||||||
| 		CookieExpire:   cookieOpts.CookieExpire, |  | ||||||
| 		CookieHTTPOnly: cookieOpts.CookieHTTPOnly, |  | ||||||
| 		CookieName:     cookieOpts.CookieName, |  | ||||||
| 		CookiePath:     cookieOpts.CookiePath, |  | ||||||
| 		CookieSecret:   cookieOpts.CookieSecret, |  | ||||||
| 		CookieSecure:   cookieOpts.CookieSecure, |  | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -12,7 +12,7 @@ import ( | ||||||
| func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||||
| 	switch opts.Type { | 	switch opts.Type { | ||||||
| 	case options.CookieSessionStoreType: | 	case options.CookieSessionStoreType: | ||||||
| 		return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) | 		return cookie.NewCookieSessionStore(opts, cookieOpts) | ||||||
| 	default: | 	default: | ||||||
| 		return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) | 		return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -5,16 +5,20 @@ import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
| 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/cookies" | 	"github.com/pusher/oauth2_proxy/pkg/cookies" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions" | 	"github.com/pusher/oauth2_proxy/pkg/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | 	sessionscookie "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions/utils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestSessionStore(t *testing.T) { | func TestSessionStore(t *testing.T) { | ||||||
|  | @ -72,6 +76,16 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 				} | 				} | ||||||
| 			}) | 			}) | ||||||
| 
 | 
 | ||||||
|  | 			It("have a signature timestamp matching session.CreatedAt", func() { | ||||||
|  | 				for _, cookie := range cookies { | ||||||
|  | 					if cookie.Value != "" { | ||||||
|  | 						parts := strings.Split(cookie.Value, "|") | ||||||
|  | 						Expect(parts).To(HaveLen(3)) | ||||||
|  | 						Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix())))) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -86,6 +100,10 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 				Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) | 				Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) | ||||||
| 			}) | 			}) | ||||||
| 
 | 
 | ||||||
|  | 			It("Ensures the session CreatedAt is not zero", func() { | ||||||
|  | 				Expect(session.CreatedAt.IsZero()).To(BeFalse()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
| 			CheckCookieOptions() | 			CheckCookieOptions() | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
|  | @ -138,12 +156,15 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 
 | 
 | ||||||
| 					// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | 					// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | ||||||
| 					l := *loadedSession | 					l := *loadedSession | ||||||
|  | 					l.CreatedAt = time.Time{} | ||||||
| 					l.ExpiresOn = time.Time{} | 					l.ExpiresOn = time.Time{} | ||||||
| 					s := *session | 					s := *session | ||||||
|  | 					s.CreatedAt = time.Time{} | ||||||
| 					s.ExpiresOn = time.Time{} | 					s.ExpiresOn = time.Time{} | ||||||
| 					Expect(l).To(Equal(s)) | 					Expect(l).To(Equal(s)) | ||||||
| 
 | 
 | ||||||
| 					// Compare time.Time separately
 | 					// Compare time.Time separately
 | ||||||
|  | 					Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) | ||||||
| 					Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) | 					Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) | ||||||
| 				} | 				} | ||||||
| 			}) | 			}) | ||||||
|  | @ -181,12 +202,16 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 			SessionStoreInterfaceTests() | 			SessionStoreInterfaceTests() | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
| 		Context("with a cookie-secret set", func() { | 		Context("with a cipher", func() { | ||||||
| 			BeforeEach(func() { | 			BeforeEach(func() { | ||||||
| 				secret := make([]byte, 32) | 				secret := make([]byte, 32) | ||||||
| 				_, err := rand.Read(secret) | 				_, err := rand.Read(secret) | ||||||
| 				Expect(err).ToNot(HaveOccurred()) | 				Expect(err).ToNot(HaveOccurred()) | ||||||
| 				cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) | 				cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) | ||||||
|  | 				cipher, err := cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(cipher).ToNot(BeNil()) | ||||||
|  | 				opts.Cipher = cipher | ||||||
| 
 | 
 | ||||||
| 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||||
| 				Expect(err).ToNot(HaveOccurred()) | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | @ -231,7 +256,7 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 		It("creates a cookie.SessionStore", func() { | 		It("creates a cookie.SessionStore", func() { | ||||||
| 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||||
| 			Expect(err).NotTo(HaveOccurred()) | 			Expect(err).NotTo(HaveOccurred()) | ||||||
| 			Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{})) | 			Expect(ss).To(BeAssignableToTypeOf(&sessionscookie.SessionStore{})) | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
| 		Context("the cookie.SessionStore", func() { | 		Context("the cookie.SessionStore", func() { | ||||||
|  |  | ||||||
|  | @ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt | ||||||
| 	s = &sessions.SessionState{ | 	s = &sessions.SessionState{ | ||||||
| 		AccessToken:  jsonResponse.AccessToken, | 		AccessToken:  jsonResponse.AccessToken, | ||||||
| 		IDToken:      jsonResponse.IDToken, | 		IDToken:      jsonResponse.IDToken, | ||||||
|  | 		CreatedAt:    time.Now(), | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||||
| 		RefreshToken: jsonResponse.RefreshToken, | 		RefreshToken: jsonResponse.RefreshToken, | ||||||
| 		Email:        c.Email, | 		Email:        c.Email, | ||||||
|  |  | ||||||
|  | @ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session | ||||||
| 	s = &sessions.SessionState{ | 	s = &sessions.SessionState{ | ||||||
| 		AccessToken: jsonResponse.AccessToken, | 		AccessToken: jsonResponse.AccessToken, | ||||||
| 		IDToken:     jsonResponse.IDToken, | 		IDToken:     jsonResponse.IDToken, | ||||||
|  | 		CreatedAt:   time.Now(), | ||||||
| 		ExpiresOn:   time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | 		ExpiresOn:   time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||||
| 		Email:       email, | 		Email:       email, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) | ||||||
| 	s.AccessToken = newSession.AccessToken | 	s.AccessToken = newSession.AccessToken | ||||||
| 	s.IDToken = newSession.IDToken | 	s.IDToken = newSession.IDToken | ||||||
| 	s.RefreshToken = newSession.RefreshToken | 	s.RefreshToken = newSession.RefreshToken | ||||||
|  | 	s.CreatedAt = newSession.CreatedAt | ||||||
| 	s.ExpiresOn = newSession.ExpiresOn | 	s.ExpiresOn = newSession.ExpiresOn | ||||||
| 	s.Email = newSession.Email | 	s.Email = newSession.Email | ||||||
| 	return | 	return | ||||||
|  | @ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | ||||||
| 		AccessToken:  token.AccessToken, | 		AccessToken:  token.AccessToken, | ||||||
| 		IDToken:      rawIDToken, | 		IDToken:      rawIDToken, | ||||||
| 		RefreshToken: token.RefreshToken, | 		RefreshToken: token.RefreshToken, | ||||||
|  | 		CreatedAt:    time.Now(), | ||||||
| 		ExpiresOn:    token.Expiry, | 		ExpiresOn:    token.Expiry, | ||||||
| 		Email:        claims.Email, | 		Email:        claims.Email, | ||||||
| 		User:         claims.Subject, | 		User:         claims.Subject, | ||||||
|  |  | ||||||
|  | @ -8,6 +8,7 @@ import ( | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | @ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if a := v.Get("access_token"); a != "" { | 	if a := v.Get("access_token"); a != "" { | ||||||
| 		s = &sessions.SessionState{AccessToken: a} | 		s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()} | ||||||
| 	} else { | 	} else { | ||||||
| 		err = fmt.Errorf("no access token found %s", body) | 		err = fmt.Errorf("no access token found %s", body) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue