Merge pull request #148 from pusher/proxy-session-store
Proxy session store
This commit is contained in:
		
						commit
						10e240c8bf
					
				|  | @ -10,6 +10,7 @@ | |||
| 
 | ||||
| ## Changes since v3.2.0 | ||||
| 
 | ||||
| - [#148](https://github.com/pusher/outh2_proxy/pull/148) Implement SessionStore interface within proxy (@JoelSpeed) | ||||
| - [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed) | ||||
|   - Allows for multiple different session storage implementations including client and server side | ||||
|   - Adds tests suite for interface to ensure consistency across implementations | ||||
|  |  | |||
							
								
								
									
										191
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										191
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -16,7 +16,7 @@ import ( | |||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/providers" | ||||
| 	"github.com/yhat/wsutil" | ||||
| ) | ||||
|  | @ -29,10 +29,6 @@ const ( | |||
| 	httpScheme  = "http" | ||||
| 	httpsScheme = "https" | ||||
| 
 | ||||
| 	// Cookies are limited to 4kb including the length of the cookie name,
 | ||||
| 	// the cookie name can be up to 256 bytes
 | ||||
| 	maxCookieLength = 3840 | ||||
| 
 | ||||
| 	applicationJSON = "application/json" | ||||
| ) | ||||
| 
 | ||||
|  | @ -75,6 +71,7 @@ type OAuthProxy struct { | |||
| 	redirectURL         *url.URL // the url to receive requests at
 | ||||
| 	whitelistDomains    []string | ||||
| 	provider            providers.Provider | ||||
| 	sessionStore        sessionsapi.SessionStore | ||||
| 	ProxyPrefix         string | ||||
| 	SignInMessage       string | ||||
| 	HtpasswdFile        *HtpasswdFile | ||||
|  | @ -88,7 +85,6 @@ type OAuthProxy struct { | |||
| 	PassAccessToken     bool | ||||
| 	SetAuthorization    bool | ||||
| 	PassAuthorization   bool | ||||
| 	CookieCipher        *cookie.Cipher | ||||
| 	skipAuthRegex       []string | ||||
| 	skipAuthPreflight   bool | ||||
| 	compiledRegex       []*regexp.Regexp | ||||
|  | @ -218,15 +214,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | |||
| 
 | ||||
| 	logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, refresh) | ||||
| 
 | ||||
| 	var cipher *cookie.Cipher | ||||
| 	if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) { | ||||
| 		var err error | ||||
| 		cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) | ||||
| 		if err != nil { | ||||
| 			logger.Fatal("cookie-secret error: ", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return &OAuthProxy{ | ||||
| 		CookieName:     opts.CookieName, | ||||
| 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), | ||||
|  | @ -249,6 +236,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | |||
| 
 | ||||
| 		ProxyPrefix:        opts.ProxyPrefix, | ||||
| 		provider:           opts.provider, | ||||
| 		sessionStore:       opts.sessionStore, | ||||
| 		serveMux:           serveMux, | ||||
| 		redirectURL:        redirectURL, | ||||
| 		whitelistDomains:   opts.WhitelistDomains, | ||||
|  | @ -263,7 +251,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | |||
| 		SetAuthorization:   opts.SetAuthorization, | ||||
| 		PassAuthorization:  opts.PassAuthorization, | ||||
| 		SkipProviderButton: opts.SkipProviderButton, | ||||
| 		CookieCipher:       cipher, | ||||
| 		templates:          loadTemplates(opts.CustomTemplatesDir), | ||||
| 		Footer:             opts.Footer, | ||||
| 	} | ||||
|  | @ -293,7 +280,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { | |||
| 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { | ||||
| func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("missing code") | ||||
| 	} | ||||
|  | @ -316,104 +303,6 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, er | |||
| 	return | ||||
| } | ||||
| 
 | ||||
| // MakeSessionCookie creates an http.Cookie containing the authenticated user's
 | ||||
| // authentication details
 | ||||
| func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { | ||||
| 	if value != "" { | ||||
| 		value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) | ||||
| 	} | ||||
| 	c := p.makeCookie(req, p.CookieName, value, expiration, now) | ||||
| 	if len(c.Value) > 4096-len(p.CookieName) { | ||||
| 		return splitCookie(c) | ||||
| 	} | ||||
| 	return []*http.Cookie{c} | ||||
| } | ||||
| 
 | ||||
| func copyCookie(c *http.Cookie) *http.Cookie { | ||||
| 	return &http.Cookie{ | ||||
| 		Name:       c.Name, | ||||
| 		Value:      c.Value, | ||||
| 		Path:       c.Path, | ||||
| 		Domain:     c.Domain, | ||||
| 		Expires:    c.Expires, | ||||
| 		RawExpires: c.RawExpires, | ||||
| 		MaxAge:     c.MaxAge, | ||||
| 		Secure:     c.Secure, | ||||
| 		HttpOnly:   c.HttpOnly, | ||||
| 		Raw:        c.Raw, | ||||
| 		Unparsed:   c.Unparsed, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // splitCookie reads the full cookie generated to store the session and splits
 | ||||
| // it into a slice of cookies which fit within the 4kb cookie limit indexing
 | ||||
| // the cookies from 0
 | ||||
| func splitCookie(c *http.Cookie) []*http.Cookie { | ||||
| 	if len(c.Value) < maxCookieLength { | ||||
| 		return []*http.Cookie{c} | ||||
| 	} | ||||
| 	cookies := []*http.Cookie{} | ||||
| 	valueBytes := []byte(c.Value) | ||||
| 	count := 0 | ||||
| 	for len(valueBytes) > 0 { | ||||
| 		new := copyCookie(c) | ||||
| 		new.Name = fmt.Sprintf("%s_%d", c.Name, count) | ||||
| 		count++ | ||||
| 		if len(valueBytes) < maxCookieLength { | ||||
| 			new.Value = string(valueBytes) | ||||
| 			valueBytes = []byte{} | ||||
| 		} else { | ||||
| 			newValue := valueBytes[:maxCookieLength] | ||||
| 			valueBytes = valueBytes[maxCookieLength:] | ||||
| 			new.Value = string(newValue) | ||||
| 		} | ||||
| 		cookies = append(cookies, new) | ||||
| 	} | ||||
| 	return cookies | ||||
| } | ||||
| 
 | ||||
| // joinCookies takes a slice of cookies from the request and reconstructs the
 | ||||
| // full session cookie
 | ||||
| func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { | ||||
| 	if len(cookies) == 0 { | ||||
| 		return nil, fmt.Errorf("list of cookies must be > 0") | ||||
| 	} | ||||
| 	if len(cookies) == 1 { | ||||
| 		return cookies[0], nil | ||||
| 	} | ||||
| 	c := copyCookie(cookies[0]) | ||||
| 	for i := 1; i < len(cookies); i++ { | ||||
| 		c.Value += cookies[i].Value | ||||
| 	} | ||||
| 	c.Name = strings.TrimRight(c.Name, "_0") | ||||
| 	return c, nil | ||||
| } | ||||
| 
 | ||||
| // loadCookie retreieves the sessions state cookie from the http request.
 | ||||
| // If a single cookie is present this will be returned, otherwise it attempts
 | ||||
| // to reconstruct a cookie split up by splitCookie
 | ||||
| func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { | ||||
| 	c, err := req.Cookie(cookieName) | ||||
| 	if err == nil { | ||||
| 		return c, nil | ||||
| 	} | ||||
| 	cookies := []*http.Cookie{} | ||||
| 	err = nil | ||||
| 	count := 0 | ||||
| 	for err == nil { | ||||
| 		var c *http.Cookie | ||||
| 		c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count)) | ||||
| 		if err == nil { | ||||
| 			cookies = append(cookies, c) | ||||
| 			count++ | ||||
| 		} | ||||
| 	} | ||||
| 	if len(cookies) == 0 { | ||||
| 		return nil, fmt.Errorf("Could not find cookie %s", cookieName) | ||||
| 	} | ||||
| 	return joinCookies(cookies) | ||||
| } | ||||
| 
 | ||||
| // MakeCSRFCookie creates a cookie for CSRF
 | ||||
| func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||
| 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | ||||
|  | @ -454,66 +343,18 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va | |||
| 
 | ||||
| // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | ||||
| // stored in the user's session
 | ||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { | ||||
| 	var cookies []*http.Cookie | ||||
| 
 | ||||
| 	// matches CookieName, CookieName_<number>
 | ||||
| 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName)) | ||||
| 
 | ||||
| 	for _, c := range req.Cookies() { | ||||
| 		if cookieNameRegex.MatchString(c.Name) { | ||||
| 			clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) | ||||
| 
 | ||||
| 			http.SetCookie(rw, clearCookie) | ||||
| 			cookies = append(cookies, clearCookie) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// ugly hack because default domain changed
 | ||||
| 	if p.CookieDomain == "" && len(cookies) > 0 { | ||||
| 		clr2 := *cookies[0] | ||||
| 		clr2.Domain = req.Host | ||||
| 		http.SetCookie(rw, &clr2) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SetSessionCookie adds the user's session cookie to the response
 | ||||
| func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | ||||
| 	for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) { | ||||
| 		http.SetCookie(rw, c) | ||||
| 	} | ||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { | ||||
| 	return p.sessionStore.Clear(rw, req) | ||||
| } | ||||
| 
 | ||||
| // LoadCookiedSession reads the user's authentication details from the request
 | ||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { | ||||
| 	var age time.Duration | ||||
| 	c, err := loadCookie(req, p.CookieName) | ||||
| 	if err != nil { | ||||
| 		// always http.ErrNoCookie
 | ||||
| 		return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) | ||||
| 	} | ||||
| 	val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) | ||||
| 	if !ok { | ||||
| 		return nil, age, errors.New("Cookie Signature not valid") | ||||
| 	} | ||||
| 
 | ||||
| 	session, err := p.provider.SessionFromCookie(val, p.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return nil, age, err | ||||
| 	} | ||||
| 
 | ||||
| 	age = time.Now().Truncate(time.Second).Sub(timestamp) | ||||
| 	return session, age, nil | ||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	return p.sessionStore.Load(req) | ||||
| } | ||||
| 
 | ||||
| // SaveSession creates a new session cookie value and sets this on the response
 | ||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | ||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	p.SetSessionCookie(rw, req, value) | ||||
| 	return nil | ||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | ||||
| 	return p.sessionStore.Save(rw, req, s) | ||||
| } | ||||
| 
 | ||||
| // RobotsTxt disallows scraping pages from the OAuthProxy
 | ||||
|  | @ -694,7 +535,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | |||
| 
 | ||||
| 	user, ok := p.ManualSignIn(rw, req) | ||||
| 	if ok { | ||||
| 		session := &sessions.SessionState{User: user} | ||||
| 		session := &sessionsapi.SessionState{User: user} | ||||
| 		p.SaveSession(rw, req, session) | ||||
| 		http.Redirect(rw, req, redirect, 302) | ||||
| 	} else { | ||||
|  | @ -833,12 +674,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | |||
| 	var saveSession, clearSession, revalidated bool | ||||
| 	remoteAddr := getRemoteAddr(req) | ||||
| 
 | ||||
| 	session, sessionAge, err := p.LoadCookiedSession(req) | ||||
| 	session, err := p.LoadCookiedSession(req) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("Error loading cookied session: %s", err) | ||||
| 	} | ||||
| 	if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | ||||
| 		logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh) | ||||
| 	if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | ||||
| 		logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) | ||||
| 		saveSession = true | ||||
| 	} | ||||
| 
 | ||||
|  | @ -945,7 +786,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | |||
| 
 | ||||
| // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||
| // credentials and authenticates these against the proxies HtpasswdFile
 | ||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { | ||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	if p.HtpasswdFile == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | @ -967,7 +808,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, | |||
| 	} | ||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | ||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||
| 		return &sessions.SessionState{User: pair[0]}, nil | ||||
| 		return &sessionsapi.SessionState{User: pair[0]}, nil | ||||
| 	} | ||||
| 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | ||||
| 	return nil, nil | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ import ( | |||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/providers" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
|  | @ -600,10 +601,15 @@ type ProcessCookieTestOpts struct { | |||
| 	providerValidateCookieResponse bool | ||||
| } | ||||
| 
 | ||||
| func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { | ||||
| type OptionsModifier func(*Options) | ||||
| 
 | ||||
| func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { | ||||
| 	var pcTest ProcessCookieTest | ||||
| 
 | ||||
| 	pcTest.opts = NewOptions() | ||||
| 	for _, modifier := range modifiers { | ||||
| 		modifier(pcTest.opts) | ||||
| 	} | ||||
| 	pcTest.opts.ClientID = "bazquux" | ||||
| 	pcTest.opts.ClientSecret = "xyzzyplugh" | ||||
| 	pcTest.opts.CookieSecret = "0123456789abcdefabcd" | ||||
|  | @ -634,32 +640,34 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { | |||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { | ||||
| 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | ||||
| func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest { | ||||
| 	return NewProcessCookieTest(ProcessCookieTestOpts{ | ||||
| 		providerValidateCookieResponse: true, | ||||
| 	}, modifiers...) | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { | ||||
| 	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) | ||||
| func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error { | ||||
| 	err := p.proxy.SaveSession(p.rw, p.req, s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { | ||||
| 		p.req.AddCookie(c) | ||||
| 	for _, cookie := range p.rw.Result().Cookies() { | ||||
| 		p.req.AddCookie(cookie) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { | ||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) { | ||||
| 	return p.proxy.LoadCookiedSession(p.req) | ||||
| } | ||||
| 
 | ||||
| func TestLoadCookiedSession(t *testing.T) { | ||||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 
 | ||||
| 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, time.Now()) | ||||
| 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||
| 	pcTest.SaveSession(startSession) | ||||
| 
 | ||||
| 	session, _, err := pcTest.LoadCookiedSession() | ||||
| 	session, err := pcTest.LoadCookiedSession() | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, startSession.Email, session.Email) | ||||
| 	assert.Equal(t, "john.doe@example.com", session.User) | ||||
|  | @ -669,7 +677,7 @@ func TestLoadCookiedSession(t *testing.T) { | |||
| func TestProcessCookieNoCookieError(t *testing.T) { | ||||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 
 | ||||
| 	session, _, err := pcTest.LoadCookiedSession() | ||||
| 	session, err := pcTest.LoadCookiedSession() | ||||
| 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected nil session. got %#v", session) | ||||
|  | @ -677,29 +685,31 @@ func TestProcessCookieNoCookieError(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestProcessCookieRefreshNotSet(t *testing.T) { | ||||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||
| 		opts.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	}) | ||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||
| 
 | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, reference) | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||
| 	pcTest.SaveSession(startSession) | ||||
| 
 | ||||
| 	session, age, err := pcTest.LoadCookiedSession() | ||||
| 	session, err := pcTest.LoadCookiedSession() | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	if age < time.Duration(-2)*time.Hour { | ||||
| 		t.Errorf("cookie too young %v", age) | ||||
| 	if session.Age() < time.Duration(-2)*time.Hour { | ||||
| 		t.Errorf("cookie too young %v", session.Age()) | ||||
| 	} | ||||
| 	assert.Equal(t, startSession.Email, session.Email) | ||||
| } | ||||
| 
 | ||||
| func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||
| 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	}) | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, reference) | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||
| 	pcTest.SaveSession(startSession) | ||||
| 
 | ||||
| 	session, _, err := pcTest.LoadCookiedSession() | ||||
| 	session, err := pcTest.LoadCookiedSession() | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected nil session %#v", session) | ||||
|  | @ -707,22 +717,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | ||||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||
| 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	}) | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, reference) | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||
| 	pcTest.SaveSession(startSession) | ||||
| 
 | ||||
| 	pcTest.proxy.CookieRefresh = time.Hour | ||||
| 	session, _, err := pcTest.LoadCookiedSession() | ||||
| 	session, err := pcTest.LoadCookiedSession() | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if session != nil { | ||||
| 		t.Errorf("expected nil session %#v", session) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func NewAuthOnlyEndpointTest() *ProcessCookieTest { | ||||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { | ||||
| 	pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...) | ||||
| 	pcTest.req, _ = http.NewRequest("GET", | ||||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||
| 	return pcTest | ||||
|  | @ -731,8 +742,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { | |||
| func TestAuthOnlyEndpointAccepted(t *testing.T) { | ||||
| 	test := NewAuthOnlyEndpointTest() | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	test.SaveSession(startSession, time.Now()) | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||
| 	test.SaveSession(startSession) | ||||
| 
 | ||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | ||||
| 	assert.Equal(t, http.StatusAccepted, test.rw.Code) | ||||
|  | @ -750,12 +761,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | ||||
| 	test := NewAuthOnlyEndpointTest() | ||||
| 	test.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	test := NewAuthOnlyEndpointTest(func(opts *Options) { | ||||
| 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	}) | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	test.SaveSession(startSession, reference) | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||
| 	test.SaveSession(startSession) | ||||
| 
 | ||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | ||||
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | ||||
|  | @ -766,8 +778,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | |||
| func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||
| 	test := NewAuthOnlyEndpointTest() | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	test.SaveSession(startSession, time.Now()) | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||
| 	test.SaveSession(startSession) | ||||
| 	test.validateUser = false | ||||
| 
 | ||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | ||||
|  | @ -797,8 +809,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | |||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||
| 
 | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | ||||
| 	pcTest.SaveSession(startSession, time.Now()) | ||||
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} | ||||
| 	pcTest.SaveSession(startSession) | ||||
| 
 | ||||
| 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) | ||||
| 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) | ||||
|  | @ -930,11 +942,11 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { | |||
| 
 | ||||
| 	state := &sessions.SessionState{ | ||||
| 		Email: "mbland@acm.org", AccessToken: "my_access_token"} | ||||
| 	value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) | ||||
| 	err = proxy.SaveSession(st.rw, req, state) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) { | ||||
| 	for _, c := range st.rw.Result().Cookies() { | ||||
| 		req.AddCookie(c) | ||||
| 	} | ||||
| 	// This is used by the upstream to validate the signature.
 | ||||
|  | @ -1068,7 +1080,12 @@ func TestAjaxForbiddendRequest(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestClearSplitCookie(t *testing.T) { | ||||
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} | ||||
| 	opts := NewOptions() | ||||
| 	opts.CookieName = "oauth2" | ||||
| 	opts.CookieDomain = "abc" | ||||
| 	store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) | ||||
| 	assert.Equal(t, err, nil) | ||||
| 	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} | ||||
| 	var rw = httptest.NewRecorder() | ||||
| 	req := httptest.NewRequest("get", "/", nil) | ||||
| 
 | ||||
|  | @ -1092,7 +1109,12 @@ func TestClearSplitCookie(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestClearSingleCookie(t *testing.T) { | ||||
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} | ||||
| 	opts := NewOptions() | ||||
| 	opts.CookieName = "oauth2" | ||||
| 	opts.CookieDomain = "abc" | ||||
| 	store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) | ||||
| 	assert.Equal(t, err, nil) | ||||
| 	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} | ||||
| 	var rw = httptest.NewRecorder() | ||||
| 	req := httptest.NewRequest("get", "/", nil) | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										24
									
								
								options.go
								
								
								
								
							
							
						
						
									
										24
									
								
								options.go
								
								
								
								
							|  | @ -17,8 +17,11 @@ import ( | |||
| 	oidc "github.com/coreos/go-oidc" | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||
| 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/providers" | ||||
| 	"gopkg.in/natefinch/lumberjack.v2" | ||||
| ) | ||||
|  | @ -111,6 +114,7 @@ type Options struct { | |||
| 	proxyURLs     []*url.URL | ||||
| 	CompiledRegex []*regexp.Regexp | ||||
| 	provider      providers.Provider | ||||
| 	sessionStore  sessionsapi.SessionStore | ||||
| 	signatureData *SignatureData | ||||
| 	oidcVerifier  *oidc.IDTokenVerifier | ||||
| } | ||||
|  | @ -136,6 +140,9 @@ func NewOptions() *Options { | |||
| 			CookieExpire:   time.Duration(168) * time.Hour, | ||||
| 			CookieRefresh:  time.Duration(0), | ||||
| 		}, | ||||
| 		SessionOptions: options.SessionOptions{ | ||||
| 			Type: "cookie", | ||||
| 		}, | ||||
| 		SetXAuthRequest:       false, | ||||
| 		SkipAuthPreflight:     false, | ||||
| 		PassBasicAuth:         true, | ||||
|  | @ -261,7 +268,8 @@ func (o *Options) Validate() error { | |||
| 	} | ||||
| 	msgs = parseProviderInfo(o, msgs) | ||||
| 
 | ||||
| 	if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { | ||||
| 	var cipher *cookie.Cipher | ||||
| 	if o.PassAccessToken || o.SetAuthorization || o.PassAuthorization || (o.CookieRefresh != time.Duration(0)) { | ||||
| 		validCookieSecretSize := false | ||||
| 		for _, i := range []int{16, 24, 32} { | ||||
| 			if len(secretBytes(o.CookieSecret)) == i { | ||||
|  | @ -283,9 +291,23 @@ func (o *Options) Validate() error { | |||
| 					"pass_access_token == true or "+ | ||||
| 					"cookie_refresh != 0, but is %d bytes.%s", | ||||
| 				len(secretBytes(o.CookieSecret)), suffix)) | ||||
| 		} else { | ||||
| 			var err error | ||||
| 			cipher, err = cookie.NewCipher(secretBytes(o.CookieSecret)) | ||||
| 			if err != nil { | ||||
| 				msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	o.SessionOptions.Cipher = cipher | ||||
| 	sessionStore, err := sessions.NewSessionStore(&o.SessionOptions, &o.CookieOptions) | ||||
| 	if err != nil { | ||||
| 		msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err)) | ||||
| 	} else { | ||||
| 		o.sessionStore = sessionStore | ||||
| 	} | ||||
| 
 | ||||
| 	if o.CookieRefresh >= o.CookieExpire { | ||||
| 		msgs = append(msgs, fmt.Sprintf( | ||||
| 			"cookie_refresh (%s) must be less than "+ | ||||
|  |  | |||
|  | @ -1,8 +1,13 @@ | |||
| package options | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| ) | ||||
| 
 | ||||
| // SessionOptions contains configuration options for the SessionStore providers.
 | ||||
| type SessionOptions struct { | ||||
| 	Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | ||||
| 	Type   string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | ||||
| 	Cipher *cookie.Cipher | ||||
| 	CookieStoreOptions | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ import ( | |||
| type SessionState struct { | ||||
| 	AccessToken  string    `json:",omitempty"` | ||||
| 	IDToken      string    `json:",omitempty"` | ||||
| 	CreatedAt    time.Time `json:"-"` | ||||
| 	ExpiresOn    time.Time `json:"-"` | ||||
| 	RefreshToken string    `json:",omitempty"` | ||||
| 	Email        string    `json:",omitempty"` | ||||
|  | @ -23,6 +24,7 @@ type SessionState struct { | |||
| // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
 | ||||
| type SessionStateJSON struct { | ||||
| 	*SessionState | ||||
| 	CreatedAt *time.Time `json:",omitempty"` | ||||
| 	ExpiresOn *time.Time `json:",omitempty"` | ||||
| } | ||||
| 
 | ||||
|  | @ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool { | |||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // Age returns the age of a session
 | ||||
| func (s *SessionState) Age() time.Duration { | ||||
| 	if !s.CreatedAt.IsZero() { | ||||
| 		return time.Now().Truncate(time.Second).Sub(s.CreatedAt) | ||||
| 	} | ||||
| 	return 0 | ||||
| } | ||||
| 
 | ||||
| // String constructs a summary of the session state
 | ||||
| func (s *SessionState) String() string { | ||||
| 	o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) | ||||
|  | @ -43,6 +53,9 @@ func (s *SessionState) String() string { | |||
| 	if s.IDToken != "" { | ||||
| 		o += " id_token:true" | ||||
| 	} | ||||
| 	if !s.CreatedAt.IsZero() { | ||||
| 		o += fmt.Sprintf(" created:%s", s.CreatedAt) | ||||
| 	} | ||||
| 	if !s.ExpiresOn.IsZero() { | ||||
| 		o += fmt.Sprintf(" expires:%s", s.ExpiresOn) | ||||
| 	} | ||||
|  | @ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { | |||
| 	} | ||||
| 	// Embed SessionState and ExpiresOn pointer into SessionStateJSON
 | ||||
| 	ssj := &SessionStateJSON{SessionState: &ss} | ||||
| 	if !ss.CreatedAt.IsZero() { | ||||
| 		ssj.CreatedAt = &ss.CreatedAt | ||||
| 	} | ||||
| 	if !ss.ExpiresOn.IsZero() { | ||||
| 		ssj.ExpiresOn = &ss.ExpiresOn | ||||
| 	} | ||||
|  | @ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { | |||
| 	var ss *SessionState | ||||
| 	err := json.Unmarshal([]byte(v), &ssj) | ||||
| 	if err == nil && ssj.SessionState != nil { | ||||
| 		// Extract SessionState and ExpiresOn value from SessionStateJSON
 | ||||
| 		// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
 | ||||
| 		ss = ssj.SessionState | ||||
| 		if ssj.CreatedAt != nil { | ||||
| 			ss.CreatedAt = *ssj.CreatedAt | ||||
| 		} | ||||
| 		if ssj.ExpiresOn != nil { | ||||
| 			ss.ExpiresOn = *ssj.ExpiresOn | ||||
| 		} | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		IDToken:      "rawtoken1234", | ||||
| 		CreatedAt:    time.Now(), | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
| 		RefreshToken: "refresh4321", | ||||
| 	} | ||||
|  | @ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.Equal(t, s.IDToken, ss.IDToken) | ||||
| 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
| 
 | ||||
|  | @ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.NotEqual(t, "user@domain.com", ss.User) | ||||
| 	assert.NotEqual(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.NotEqual(t, s.IDToken, ss.IDToken) | ||||
|  | @ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| 		User:         "just-user", | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		CreatedAt:    time.Now(), | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
| 		RefreshToken: "refresh4321", | ||||
| 	} | ||||
|  | @ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| 	assert.Equal(t, s.User, ss.User) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
| 
 | ||||
|  | @ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.NotEqual(t, s.User, ss.User) | ||||
| 	assert.NotEqual(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) | ||||
|  | @ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | |||
| 	s := &sessions.SessionState{ | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		CreatedAt:    time.Now(), | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
| 		RefreshToken: "refresh4321", | ||||
| 	} | ||||
|  | @ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | |||
| 		User:         "just-user", | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		CreatedAt:    time.Now(), | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
| 		RefreshToken: "refresh4321", | ||||
| 	} | ||||
|  | @ -147,6 +155,7 @@ type testCase struct { | |||
| // Currently only tests without cipher here because we have no way to mock
 | ||||
| // the random generator used in EncodeSessionState.
 | ||||
| func TestEncodeSessionState(t *testing.T) { | ||||
| 	c := time.Now() | ||||
| 	e := time.Now().Add(time.Duration(1) * time.Hour) | ||||
| 
 | ||||
| 	testCases := []testCase{ | ||||
|  | @ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) { | |||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
| 				IDToken:      "rawtoken1234", | ||||
| 				CreatedAt:    c, | ||||
| 				ExpiresOn:    e, | ||||
| 				RefreshToken: "refresh4321", | ||||
| 			}, | ||||
|  | @ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) { | |||
| 
 | ||||
| // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
 | ||||
| func TestDecodeSessionState(t *testing.T) { | ||||
| 	created := time.Now() | ||||
| 	createdJSON, _ := created.MarshalJSON() | ||||
| 	createdString := string(createdJSON) | ||||
| 	e := time.Now().Add(time.Duration(1) * time.Hour) | ||||
| 	eJSON, _ := e.MarshalJSON() | ||||
| 	eString := string(eJSON) | ||||
|  | @ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), | ||||
| 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
|  | @ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
| 				IDToken:      "rawtoken1234", | ||||
| 				CreatedAt:    created, | ||||
| 				ExpiresOn:    e, | ||||
| 				RefreshToken: "refresh4321", | ||||
| 			}, | ||||
| 			Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), | ||||
| 			Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), | ||||
| 			Cipher:  c, | ||||
| 		}, | ||||
| 		{ | ||||
|  | @ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSessionStateAge(t *testing.T) { | ||||
| 	ss := &sessions.SessionState{} | ||||
| 
 | ||||
| 	// Created at unset so should be 0
 | ||||
| 	assert.Equal(t, time.Duration(0), ss.Age()) | ||||
| 
 | ||||
| 	// Set CreatedAt to 1 hour ago
 | ||||
| 	ss.CreatedAt = time.Now().Add(-1 * time.Hour) | ||||
| 	assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) | ||||
| } | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||
| ) | ||||
| 
 | ||||
| // MakeCookie constructs a cookie from the given parameters,
 | ||||
|  | @ -32,3 +33,9 @@ func MakeCookie(req *http.Request, name string, value string, path string, domai | |||
| 		Expires:  now.Add(expiration), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // MakeCookieFromOptions constructs a cookie based on the givemn *options.CookieOptions,
 | ||||
| // value and creation time
 | ||||
| func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.CookieOptions, expiration time.Duration, now time.Time) *http.Cookie { | ||||
| 	return MakeCookie(req, name, value, opts.CookiePath, opts.CookieDomain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now) | ||||
| } | ||||
|  |  | |||
|  | @ -27,36 +27,33 @@ var _ sessions.SessionStore = &SessionStore{} | |||
| // SessionStore is an implementation of the sessions.SessionStore
 | ||||
| // interface that stores sessions in client side cookies
 | ||||
| type SessionStore struct { | ||||
| 	CookieCipher   *cookie.Cipher | ||||
| 	CookieDomain   string | ||||
| 	CookieExpire   time.Duration | ||||
| 	CookieHTTPOnly bool | ||||
| 	CookieName     string | ||||
| 	CookiePath     string | ||||
| 	CookieSecret   string | ||||
| 	CookieSecure   bool | ||||
| 	CookieOptions *options.CookieOptions | ||||
| 	CookieCipher  *cookie.Cipher | ||||
| } | ||||
| 
 | ||||
| // Save takes a sessions.SessionState and stores the information from it
 | ||||
| // within Cookies set on the HTTP response writer
 | ||||
| func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | ||||
| 	if ss.CreatedAt.IsZero() { | ||||
| 		ss.CreatedAt = time.Now() | ||||
| 	} | ||||
| 	value, err := utils.CookieForSession(ss, s.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	s.setSessionCookie(rw, req, value) | ||||
| 	s.setSessionCookie(rw, req, value, ss.CreatedAt) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Load reads sessions.SessionState information from Cookies within the
 | ||||
| // HTTP request object
 | ||||
| func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { | ||||
| 	c, err := loadCookie(req, s.CookieName) | ||||
| 	c, err := loadCookie(req, s.CookieOptions.CookieName) | ||||
| 	if err != nil { | ||||
| 		// always http.ErrNoCookie
 | ||||
| 		return nil, fmt.Errorf("Cookie %q not present", s.CookieName) | ||||
| 		return nil, fmt.Errorf("Cookie %q not present", s.CookieOptions.CookieName) | ||||
| 	} | ||||
| 	val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire) | ||||
| 	val, _, ok := cookie.Validate(c, s.CookieOptions.CookieSecret, s.CookieOptions.CookieExpire) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Cookie Signature not valid") | ||||
| 	} | ||||
|  | @ -74,11 +71,11 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | |||
| 	var cookies []*http.Cookie | ||||
| 
 | ||||
| 	// matches CookieName, CookieName_<number>
 | ||||
| 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName)) | ||||
| 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieOptions.CookieName)) | ||||
| 
 | ||||
| 	for _, c := range req.Cookies() { | ||||
| 		if cookieNameRegex.MatchString(c.Name) { | ||||
| 			clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) | ||||
| 			clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) | ||||
| 
 | ||||
| 			http.SetCookie(rw, clearCookie) | ||||
| 			cookies = append(cookies, clearCookie) | ||||
|  | @ -89,60 +86,42 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | |||
| } | ||||
| 
 | ||||
| // setSessionCookie adds the user's session cookie to the response
 | ||||
| func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | ||||
| 	for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { | ||||
| func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) { | ||||
| 	for _, c := range s.makeSessionCookie(req, val, created) { | ||||
| 		http.SetCookie(rw, c) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // makeSessionCookie creates an http.Cookie containing the authenticated user's
 | ||||
| // authentication details
 | ||||
| func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { | ||||
| func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie { | ||||
| 	if value != "" { | ||||
| 		value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) | ||||
| 		value = cookie.SignedValue(s.CookieOptions.CookieSecret, s.CookieOptions.CookieName, value, now) | ||||
| 	} | ||||
| 	c := s.makeCookie(req, s.CookieName, value, expiration) | ||||
| 	if len(c.Value) > 4096-len(s.CookieName) { | ||||
| 	c := s.makeCookie(req, s.CookieOptions.CookieName, value, s.CookieOptions.CookieExpire, now) | ||||
| 	if len(c.Value) > 4096-len(s.CookieOptions.CookieName) { | ||||
| 		return splitCookie(c) | ||||
| 	} | ||||
| 	return []*http.Cookie{c} | ||||
| } | ||||
| 
 | ||||
| func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { | ||||
| 	return cookies.MakeCookie( | ||||
| func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||
| 	return cookies.MakeCookieFromOptions( | ||||
| 		req, | ||||
| 		name, | ||||
| 		value, | ||||
| 		s.CookiePath, | ||||
| 		s.CookieDomain, | ||||
| 		s.CookieHTTPOnly, | ||||
| 		s.CookieSecure, | ||||
| 		s.CookieOptions, | ||||
| 		expiration, | ||||
| 		time.Now(), | ||||
| 		now, | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| // NewCookieSessionStore initialises a new instance of the SessionStore from
 | ||||
| // the configuration given
 | ||||
| func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||
| 	var cipher *cookie.Cipher | ||||
| 	if len(cookieOpts.CookieSecret) > 0 { | ||||
| 		var err error | ||||
| 		cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("unable to create cipher: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||
| 	return &SessionStore{ | ||||
| 		CookieCipher:   cipher, | ||||
| 		CookieDomain:   cookieOpts.CookieDomain, | ||||
| 		CookieExpire:   cookieOpts.CookieExpire, | ||||
| 		CookieHTTPOnly: cookieOpts.CookieHTTPOnly, | ||||
| 		CookieName:     cookieOpts.CookieName, | ||||
| 		CookiePath:     cookieOpts.CookiePath, | ||||
| 		CookieSecret:   cookieOpts.CookieSecret, | ||||
| 		CookieSecure:   cookieOpts.CookieSecure, | ||||
| 		CookieCipher:  opts.Cipher, | ||||
| 		CookieOptions: cookieOpts, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,7 +12,7 @@ import ( | |||
| func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||
| 	switch opts.Type { | ||||
| 	case options.CookieSessionStoreType: | ||||
| 		return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) | ||||
| 		return cookie.NewCookieSessionStore(opts, cookieOpts) | ||||
| 	default: | ||||
| 		return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) | ||||
| 	} | ||||
|  |  | |||
|  | @ -5,16 +5,20 @@ import ( | |||
| 	"encoding/base64" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||
| 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/cookies" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||
| 	sessionscookie "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions/utils" | ||||
| ) | ||||
| 
 | ||||
| func TestSessionStore(t *testing.T) { | ||||
|  | @ -72,6 +76,16 @@ var _ = Describe("NewSessionStore", func() { | |||
| 				} | ||||
| 			}) | ||||
| 
 | ||||
| 			It("have a signature timestamp matching session.CreatedAt", func() { | ||||
| 				for _, cookie := range cookies { | ||||
| 					if cookie.Value != "" { | ||||
| 						parts := strings.Split(cookie.Value, "|") | ||||
| 						Expect(parts).To(HaveLen(3)) | ||||
| 						Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix())))) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
| 
 | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
|  | @ -86,6 +100,10 @@ var _ = Describe("NewSessionStore", func() { | |||
| 				Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Ensures the session CreatedAt is not zero", func() { | ||||
| 				Expect(session.CreatedAt.IsZero()).To(BeFalse()) | ||||
| 			}) | ||||
| 
 | ||||
| 			CheckCookieOptions() | ||||
| 		}) | ||||
| 
 | ||||
|  | @ -138,12 +156,15 @@ var _ = Describe("NewSessionStore", func() { | |||
| 
 | ||||
| 					// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | ||||
| 					l := *loadedSession | ||||
| 					l.CreatedAt = time.Time{} | ||||
| 					l.ExpiresOn = time.Time{} | ||||
| 					s := *session | ||||
| 					s.CreatedAt = time.Time{} | ||||
| 					s.ExpiresOn = time.Time{} | ||||
| 					Expect(l).To(Equal(s)) | ||||
| 
 | ||||
| 					// Compare time.Time separately
 | ||||
| 					Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) | ||||
| 					Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) | ||||
| 				} | ||||
| 			}) | ||||
|  | @ -181,12 +202,16 @@ var _ = Describe("NewSessionStore", func() { | |||
| 			SessionStoreInterfaceTests() | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with a cookie-secret set", func() { | ||||
| 		Context("with a cipher", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				secret := make([]byte, 32) | ||||
| 				_, err := rand.Read(secret) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) | ||||
| 				cipher, err := cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(cipher).ToNot(BeNil()) | ||||
| 				opts.Cipher = cipher | ||||
| 
 | ||||
| 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
|  | @ -231,7 +256,7 @@ var _ = Describe("NewSessionStore", func() { | |||
| 		It("creates a cookie.SessionStore", func() { | ||||
| 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 			Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{})) | ||||
| 			Expect(ss).To(BeAssignableToTypeOf(&sessionscookie.SessionStore{})) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("the cookie.SessionStore", func() { | ||||
|  |  | |||
|  | @ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt | |||
| 	s = &sessions.SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		IDToken:      jsonResponse.IDToken, | ||||
| 		CreatedAt:    time.Now(), | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||
| 		RefreshToken: jsonResponse.RefreshToken, | ||||
| 		Email:        c.Email, | ||||
|  |  | |||
|  | @ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session | |||
| 	s = &sessions.SessionState{ | ||||
| 		AccessToken: jsonResponse.AccessToken, | ||||
| 		IDToken:     jsonResponse.IDToken, | ||||
| 		CreatedAt:   time.Now(), | ||||
| 		ExpiresOn:   time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||
| 		Email:       email, | ||||
| 	} | ||||
|  |  | |||
|  | @ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) | |||
| 	s.AccessToken = newSession.AccessToken | ||||
| 	s.IDToken = newSession.IDToken | ||||
| 	s.RefreshToken = newSession.RefreshToken | ||||
| 	s.CreatedAt = newSession.CreatedAt | ||||
| 	s.ExpiresOn = newSession.ExpiresOn | ||||
| 	s.Email = newSession.Email | ||||
| 	return | ||||
|  | @ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | |||
| 		AccessToken:  token.AccessToken, | ||||
| 		IDToken:      rawIDToken, | ||||
| 		RefreshToken: token.RefreshToken, | ||||
| 		CreatedAt:    time.Now(), | ||||
| 		ExpiresOn:    token.Expiry, | ||||
| 		Email:        claims.Email, | ||||
| 		User:         claims.Subject, | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ import ( | |||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
|  | @ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat | |||
| 		return | ||||
| 	} | ||||
| 	if a := v.Get("access_token"); a != "" { | ||||
| 		s = &sessions.SessionState{AccessToken: a} | ||||
| 		s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()} | ||||
| 	} else { | ||||
| 		err = fmt.Errorf("no access token found %s", body) | ||||
| 	} | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue