Use SessionStore for session in proxy
This commit is contained in:
		
							parent
							
								
									34cbe0497c
								
							
						
					
					
						commit
						c61f3a1c65
					
				|  | @ -456,27 +456,8 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va | ||||||
| 
 | 
 | ||||||
| // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | ||||||
| // stored in the user's session
 | // stored in the user's session
 | ||||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { | ||||||
| 	var cookies []*http.Cookie | 	return p.sessionStore.Clear(rw, req) | ||||||
| 
 |  | ||||||
| 	// matches CookieName, CookieName_<number>
 |  | ||||||
| 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName)) |  | ||||||
| 
 |  | ||||||
| 	for _, c := range req.Cookies() { |  | ||||||
| 		if cookieNameRegex.MatchString(c.Name) { |  | ||||||
| 			clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) |  | ||||||
| 
 |  | ||||||
| 			http.SetCookie(rw, clearCookie) |  | ||||||
| 			cookies = append(cookies, clearCookie) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// ugly hack because default domain changed
 |  | ||||||
| 	if p.CookieDomain == "" && len(cookies) > 0 { |  | ||||||
| 		clr2 := *cookies[0] |  | ||||||
| 		clr2.Domain = req.Host |  | ||||||
| 		http.SetCookie(rw, &clr2) |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SetSessionCookie adds the user's session cookie to the response
 | // SetSessionCookie adds the user's session cookie to the response
 | ||||||
|  | @ -487,35 +468,13 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // LoadCookiedSession reads the user's authentication details from the request
 | // LoadCookiedSession reads the user's authentication details from the request
 | ||||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, time.Duration, error) { | func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
| 	var age time.Duration | 	return p.sessionStore.Load(req) | ||||||
| 	c, err := loadCookie(req, p.CookieName) |  | ||||||
| 	if err != nil { |  | ||||||
| 		// always http.ErrNoCookie
 |  | ||||||
| 		return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) |  | ||||||
| 	} |  | ||||||
| 	val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, age, errors.New("Cookie Signature not valid") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	session, err := p.provider.SessionFromCookie(val, p.CookieCipher) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, age, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	age = time.Now().Truncate(time.Second).Sub(timestamp) |  | ||||||
| 	return session, age, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SaveSession creates a new session cookie value and sets this on the response
 | // SaveSession creates a new session cookie value and sets this on the response
 | ||||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | ||||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | 	return p.sessionStore.Save(rw, req, s) | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	p.SetSessionCookie(rw, req, value) |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RobotsTxt disallows scraping pages from the OAuthProxy
 | // RobotsTxt disallows scraping pages from the OAuthProxy
 | ||||||
|  | @ -835,12 +794,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 	var saveSession, clearSession, revalidated bool | 	var saveSession, clearSession, revalidated bool | ||||||
| 	remoteAddr := getRemoteAddr(req) | 	remoteAddr := getRemoteAddr(req) | ||||||
| 
 | 
 | ||||||
| 	session, sessionAge, err := p.LoadCookiedSession(req) | 	session, err := p.LoadCookiedSession(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("Error loading cookied session: %s", err) | 		logger.Printf("Error loading cookied session: %s", err) | ||||||
| 	} | 	} | ||||||
| 	if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | 	if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | ||||||
| 		logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh) | 		logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) | ||||||
| 		saveSession = true | 		saveSession = true | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ import ( | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
|  | @ -600,10 +601,15 @@ type ProcessCookieTestOpts struct { | ||||||
| 	providerValidateCookieResponse bool | 	providerValidateCookieResponse bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { | type OptionsModifier func(*Options) | ||||||
|  | 
 | ||||||
|  | func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { | ||||||
| 	var pcTest ProcessCookieTest | 	var pcTest ProcessCookieTest | ||||||
| 
 | 
 | ||||||
| 	pcTest.opts = NewOptions() | 	pcTest.opts = NewOptions() | ||||||
|  | 	for _, modifier := range modifiers { | ||||||
|  | 		modifier(pcTest.opts) | ||||||
|  | 	} | ||||||
| 	pcTest.opts.ClientID = "bazquux" | 	pcTest.opts.ClientID = "bazquux" | ||||||
| 	pcTest.opts.ClientSecret = "xyzzyplugh" | 	pcTest.opts.ClientSecret = "xyzzyplugh" | ||||||
| 	pcTest.opts.CookieSecret = "0123456789abcdefabcd" | 	pcTest.opts.CookieSecret = "0123456789abcdefabcd" | ||||||
|  | @ -634,32 +640,38 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest { | ||||||
|  | 	return NewProcessCookieTest(ProcessCookieTestOpts{ | ||||||
|  | 		providerValidateCookieResponse: true, | ||||||
|  | 	}, modifiers...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { | func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { | ||||||
| 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { | func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error { | ||||||
| 	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) | 	err := p.proxy.SaveSession(p.rw, p.req, s) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { | 	for _, cookie := range p.rw.Result().Cookies() { | ||||||
| 		p.req.AddCookie(c) | 		p.req.AddCookie(cookie) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { | func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) { | ||||||
| 	return p.proxy.LoadCookiedSession(p.req) | 	return p.proxy.LoadCookiedSession(p.req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestLoadCookiedSession(t *testing.T) { | func TestLoadCookiedSession(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 
 | 
 | ||||||
| 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||||
| 	pcTest.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, startSession.Email, session.Email) | 	assert.Equal(t, startSession.Email, session.Email) | ||||||
| 	assert.Equal(t, "john.doe@example.com", session.User) | 	assert.Equal(t, "john.doe@example.com", session.User) | ||||||
|  | @ -669,7 +681,7 @@ func TestLoadCookiedSession(t *testing.T) { | ||||||
| func TestProcessCookieNoCookieError(t *testing.T) { | func TestProcessCookieNoCookieError(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session. got %#v", session) | 		t.Errorf("expected nil session. got %#v", session) | ||||||
|  | @ -677,29 +689,31 @@ func TestProcessCookieNoCookieError(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieRefreshNotSet(t *testing.T) { | func TestProcessCookieRefreshNotSet(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour | 		opts.CookieExpire = time.Duration(23) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||||
| 
 | 
 | ||||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	session, age, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	if age < time.Duration(-2)*time.Hour { | 	if session.Age() < time.Duration(-2)*time.Hour { | ||||||
| 		t.Errorf("cookie too young %v", age) | 		t.Errorf("cookie too young %v", session.Age()) | ||||||
| 	} | 	} | ||||||
| 	assert.Equal(t, startSession.Email, session.Email) | 	assert.Equal(t, startSession.Email, session.Email) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieFailIfCookieExpired(t *testing.T) { | func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session %#v", session) | 		t.Errorf("expected nil session %#v", session) | ||||||
|  | @ -707,22 +721,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	pcTest.proxy.CookieRefresh = time.Hour | 	pcTest.proxy.CookieRefresh = time.Hour | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session %#v", session) | 		t.Errorf("expected nil session %#v", session) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewAuthOnlyEndpointTest() *ProcessCookieTest { | func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...) | ||||||
| 	pcTest.req, _ = http.NewRequest("GET", | 	pcTest.req, _ = http.NewRequest("GET", | ||||||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||||
| 	return pcTest | 	return pcTest | ||||||
|  | @ -731,8 +746,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { | ||||||
| func TestAuthOnlyEndpointAccepted(t *testing.T) { | func TestAuthOnlyEndpointAccepted(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest() | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||||
| 	test.SaveSession(startSession, time.Now()) | 	test.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
| 	assert.Equal(t, http.StatusAccepted, test.rw.Code) | 	assert.Equal(t, http.StatusAccepted, test.rw.Code) | ||||||
|  | @ -750,12 +765,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest(func(opts *Options) { | ||||||
| 	test.proxy.CookieExpire = time.Duration(24) * time.Hour | 		opts.CookieExpire = time.Duration(24) * time.Hour | ||||||
|  | 	}) | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} | ||||||
| 	test.SaveSession(startSession, reference) | 	test.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | ||||||
|  | @ -766,8 +782,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | ||||||
| func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest() | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} | ||||||
| 	test.SaveSession(startSession, time.Now()) | 	test.SaveSession(startSession) | ||||||
| 	test.validateUser = false | 	test.validateUser = false | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
|  | @ -797,8 +813,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | ||||||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||||
| 
 | 
 | ||||||
| 	startSession := &sessions.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} | ||||||
| 	pcTest.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession) | ||||||
| 
 | 
 | ||||||
| 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) | 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) | ||||||
| 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) | 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) | ||||||
|  | @ -1068,7 +1084,12 @@ func TestAjaxForbiddendRequest(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestClearSplitCookie(t *testing.T) { | func TestClearSplitCookie(t *testing.T) { | ||||||
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} | 	opts := NewOptions() | ||||||
|  | 	opts.CookieName = "oauth2" | ||||||
|  | 	opts.CookieDomain = "abc" | ||||||
|  | 	store, err := cookie.NewCookieSessionStore(opts.SessionOptions.CookieStoreOptions, &opts.CookieOptions) | ||||||
|  | 	assert.Equal(t, err, nil) | ||||||
|  | 	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} | ||||||
| 	var rw = httptest.NewRecorder() | 	var rw = httptest.NewRecorder() | ||||||
| 	req := httptest.NewRequest("get", "/", nil) | 	req := httptest.NewRequest("get", "/", nil) | ||||||
| 
 | 
 | ||||||
|  | @ -1092,7 +1113,12 @@ func TestClearSplitCookie(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestClearSingleCookie(t *testing.T) { | func TestClearSingleCookie(t *testing.T) { | ||||||
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} | 	opts := NewOptions() | ||||||
|  | 	opts.CookieName = "oauth2" | ||||||
|  | 	opts.CookieDomain = "abc" | ||||||
|  | 	store, err := cookie.NewCookieSessionStore(opts.SessionOptions.CookieStoreOptions, &opts.CookieOptions) | ||||||
|  | 	assert.Equal(t, err, nil) | ||||||
|  | 	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} | ||||||
| 	var rw = httptest.NewRecorder() | 	var rw = httptest.NewRecorder() | ||||||
| 	req := httptest.NewRequest("get", "/", nil) | 	req := httptest.NewRequest("get", "/", nil) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -40,11 +40,14 @@ type SessionStore struct { | ||||||
| // Save takes a sessions.SessionState and stores the information from it
 | // Save takes a sessions.SessionState and stores the information from it
 | ||||||
| // within Cookies set on the HTTP response writer
 | // within Cookies set on the HTTP response writer
 | ||||||
| func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | ||||||
|  | 	if ss.CreatedAt.IsZero() { | ||||||
|  | 		ss.CreatedAt = time.Now() | ||||||
|  | 	} | ||||||
| 	value, err := utils.CookieForSession(ss, s.CookieCipher) | 	value, err := utils.CookieForSession(ss, s.CookieCipher) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	s.setSessionCookie(rw, req, value) | 	s.setSessionCookie(rw, req, value, ss.CreatedAt) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -89,8 +92,8 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // setSessionCookie adds the user's session cookie to the response
 | // setSessionCookie adds the user's session cookie to the response
 | ||||||
| func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) { | ||||||
| 	for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { | 	for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, created) { | ||||||
| 		http.SetCookie(rw, c) | 		http.SetCookie(rw, c) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -5,6 +5,8 @@ import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -72,6 +74,16 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 				} | 				} | ||||||
| 			}) | 			}) | ||||||
| 
 | 
 | ||||||
|  | 			It("have a signature timestamp matching session.CreatedAt", func() { | ||||||
|  | 				for _, cookie := range cookies { | ||||||
|  | 					if cookie.Value != "" { | ||||||
|  | 						parts := strings.Split(cookie.Value, "|") | ||||||
|  | 						Expect(parts).To(HaveLen(3)) | ||||||
|  | 						Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix())))) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -86,6 +98,10 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 				Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) | 				Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) | ||||||
| 			}) | 			}) | ||||||
| 
 | 
 | ||||||
|  | 			It("Ensures the session CreatedAt is not zero", func() { | ||||||
|  | 				Expect(session.CreatedAt.IsZero()).To(BeFalse()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
| 			CheckCookieOptions() | 			CheckCookieOptions() | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
|  | @ -138,12 +154,15 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 
 | 
 | ||||||
| 					// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | 					// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | ||||||
| 					l := *loadedSession | 					l := *loadedSession | ||||||
|  | 					l.CreatedAt = time.Time{} | ||||||
| 					l.ExpiresOn = time.Time{} | 					l.ExpiresOn = time.Time{} | ||||||
| 					s := *session | 					s := *session | ||||||
|  | 					s.CreatedAt = time.Time{} | ||||||
| 					s.ExpiresOn = time.Time{} | 					s.ExpiresOn = time.Time{} | ||||||
| 					Expect(l).To(Equal(s)) | 					Expect(l).To(Equal(s)) | ||||||
| 
 | 
 | ||||||
| 					// Compare time.Time separately
 | 					// Compare time.Time separately
 | ||||||
|  | 					Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) | ||||||
| 					Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) | 					Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) | ||||||
| 				} | 				} | ||||||
| 			}) | 			}) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue