Add session loader from session storage
This commit is contained in:
		
							parent
							
								
									7d6f2a3f45
								
							
						
					
					
						commit
						034f057b60
					
				|  | @ -0,0 +1,165 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/justinas/alice" | ||||||
|  | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // StoredSessionLoaderOptions cotnains all of the requirements to construct
 | ||||||
|  | // a stored session loader.
 | ||||||
|  | // All options must be provided.
 | ||||||
|  | type StoredSessionLoaderOptions struct { | ||||||
|  | 	// Session storage basckend
 | ||||||
|  | 	SessionStore sessionsapi.SessionStore | ||||||
|  | 
 | ||||||
|  | 	// How often should sessions be refreshed
 | ||||||
|  | 	RefreshPeriod time.Duration | ||||||
|  | 
 | ||||||
|  | 	// Provider based sesssion refreshing
 | ||||||
|  | 	RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||||
|  | 
 | ||||||
|  | 	// Provider based session validation.
 | ||||||
|  | 	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
 | ||||||
|  | 	// refresh it, we must re-validate using this validation.
 | ||||||
|  | 	ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewStoredSessionLoader creates a new storedSessionLoader which loads
 | ||||||
|  | // sessions from the session store.
 | ||||||
|  | // If no session is found, the request will be passed to the nex handler.
 | ||||||
|  | // If a session was loader by a previous handler, it will not be replaced.
 | ||||||
|  | func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { | ||||||
|  | 	ss := &storedSessionLoader{ | ||||||
|  | 		store:                              opts.SessionStore, | ||||||
|  | 		refreshPeriod:                      opts.RefreshPeriod, | ||||||
|  | 		refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, | ||||||
|  | 		validateSessionState:               opts.ValidateSessionState, | ||||||
|  | 	} | ||||||
|  | 	return ss.loadSession | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // storedSessionLoader is responsible for loading sessions from cookie
 | ||||||
|  | // identified sessions in the session store.
 | ||||||
|  | type storedSessionLoader struct { | ||||||
|  | 	store                              sessionsapi.SessionStore | ||||||
|  | 	refreshPeriod                      time.Duration | ||||||
|  | 	refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||||
|  | 	validateSessionState               func(context.Context, *sessionsapi.SessionState) bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // loadSession attempts to load a session as identified by the request cookies.
 | ||||||
|  | // If no session is found, the request will be passed to the nex handler.
 | ||||||
|  | // If a session was loader by a previous handler, it will not be replaced.
 | ||||||
|  | func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { | ||||||
|  | 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 		scope := GetRequestScope(req) | ||||||
|  | 		// If scope is nil, this will panic.
 | ||||||
|  | 		// A scope should always be injected before this handler is called.
 | ||||||
|  | 		if scope.Session != nil { | ||||||
|  | 			// The session was already loaded, pass to the next handler
 | ||||||
|  | 			next.ServeHTTP(rw, req) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		session, err := s.getValidatedSession(rw, req) | ||||||
|  | 		if err != nil { | ||||||
|  | 			// In the case when there was an error loading the session,
 | ||||||
|  | 			// we should clear the session
 | ||||||
|  | 			logger.Printf("Error loading cookied session: %v, removing session", err) | ||||||
|  | 			s.store.Clear(rw, req) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Add the session to the scope if it was found
 | ||||||
|  | 		scope.Session = session | ||||||
|  | 		next.ServeHTTP(rw, req) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getValidatedSession is responsible for loading a session and making sure
 | ||||||
|  | // that is is valid.
 | ||||||
|  | func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 	session, err := s.store.Load(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if session == nil { | ||||||
|  | 		// No session was found in the storage, nothing more to do
 | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = s.refreshSessionIfNeeded(rw, req, session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return session, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // refreshSessionIfNeeded will attempt to refresh a session if the session
 | ||||||
|  | // is older than the refresh period.
 | ||||||
|  | // It is assumed that if the provider refreshes the session, the session is now
 | ||||||
|  | // valid.
 | ||||||
|  | // If the session requires refreshing but the provider does not refresh it,
 | ||||||
|  | // we must validate the session to ensure that the returned session is still
 | ||||||
|  | // valid.
 | ||||||
|  | func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { | ||||||
|  | 		// Refresh is disabled or the session is not old enough, do nothing
 | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) | ||||||
|  | 	refreshed, err := s.refreshSessionWithProvider(rw, req, session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !refreshed { | ||||||
|  | 		// Session wasn't refreshed, so make sure it's still valid
 | ||||||
|  | 		return s.validateSession(req.Context(), session) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // refreshSessionWithProvider attempts to refresh the sessinon with the provider
 | ||||||
|  | // and will save the session if it was updated.
 | ||||||
|  | func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { | ||||||
|  | 	refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, fmt.Errorf("error refreshing access token: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !refreshed { | ||||||
|  | 		return false, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Because the session was refreshed, make sure to save it
 | ||||||
|  | 	err = s.store.Save(rw, req, session) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||||
|  | 		return false, fmt.Errorf("error saving session: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return true, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // validateSession checks whether the session has expired and performs
 | ||||||
|  | // provider validation on the session.
 | ||||||
|  | // An error implies the session is not longer valid.
 | ||||||
|  | func (s *storedSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error { | ||||||
|  | 	if session.IsExpired() { | ||||||
|  | 		return errors.New("session is expired") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !s.validateSessionState(ctx, session) { | ||||||
|  | 		return errors.New("session is invalid") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | @ -0,0 +1,524 @@ | ||||||
|  | package middleware | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||||
|  | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Stored Session Suite", func() { | ||||||
|  | 	const ( | ||||||
|  | 		refresh   = "Refresh" | ||||||
|  | 		noRefresh = "NoRefresh" | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	var ctx = context.Background() | ||||||
|  | 
 | ||||||
|  | 	Context("StoredSessionLoader", func() { | ||||||
|  | 		createdPast := time.Now().Add(-5 * time.Minute) | ||||||
|  | 		createdFuture := time.Now().Add(5 * time.Minute) | ||||||
|  | 
 | ||||||
|  | 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
|  | 			switch ss.RefreshToken { | ||||||
|  | 			case refresh: | ||||||
|  | 				ss.RefreshToken = "Refreshed" | ||||||
|  | 				return true, nil | ||||||
|  | 			case noRefresh: | ||||||
|  | 				return false, nil | ||||||
|  | 			default: | ||||||
|  | 				return false, errors.New("error refreshing session") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||||
|  | 			return ss.AccessToken != "Invalid" | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var defaultSessionStore = &fakeSessionStore{ | ||||||
|  | 			LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 				switch req.Header.Get("Cookie") { | ||||||
|  | 				case "_oauth2_proxy=NoRefreshSession": | ||||||
|  | 					return &sessionsapi.SessionState{ | ||||||
|  | 						RefreshToken: noRefresh, | ||||||
|  | 						CreatedAt:    &createdPast, | ||||||
|  | 						ExpiresOn:    &createdFuture, | ||||||
|  | 					}, nil | ||||||
|  | 				case "_oauth2_proxy=InvalidNoRefreshSession": | ||||||
|  | 					return &sessionsapi.SessionState{ | ||||||
|  | 						AccessToken:  "Invalid", | ||||||
|  | 						RefreshToken: noRefresh, | ||||||
|  | 						CreatedAt:    &createdPast, | ||||||
|  | 						ExpiresOn:    &createdFuture, | ||||||
|  | 					}, nil | ||||||
|  | 				case "_oauth2_proxy=ExpiredNoRefreshSession": | ||||||
|  | 					return &sessionsapi.SessionState{ | ||||||
|  | 						RefreshToken: noRefresh, | ||||||
|  | 						CreatedAt:    &createdPast, | ||||||
|  | 						ExpiresOn:    &createdPast, | ||||||
|  | 					}, nil | ||||||
|  | 				case "_oauth2_proxy=RefreshSession": | ||||||
|  | 					return &sessionsapi.SessionState{ | ||||||
|  | 						RefreshToken: refresh, | ||||||
|  | 						CreatedAt:    &createdPast, | ||||||
|  | 						ExpiresOn:    &createdFuture, | ||||||
|  | 					}, nil | ||||||
|  | 				case "_oauth2_proxy=RefreshError": | ||||||
|  | 					return &sessionsapi.SessionState{ | ||||||
|  | 						RefreshToken: "RefreshError", | ||||||
|  | 						CreatedAt:    &createdPast, | ||||||
|  | 						ExpiresOn:    &createdFuture, | ||||||
|  | 					}, nil | ||||||
|  | 				case "_oauth2_proxy=NonExistent": | ||||||
|  | 					return nil, fmt.Errorf("invalid cookie") | ||||||
|  | 				default: | ||||||
|  | 					return nil, nil | ||||||
|  | 				} | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		type storedSessionLoaderTableInput struct { | ||||||
|  | 			requestHeaders  http.Header | ||||||
|  | 			existingSession *sessionsapi.SessionState | ||||||
|  | 			expectedSession *sessionsapi.SessionState | ||||||
|  | 			store           sessionsapi.SessionStore | ||||||
|  | 			refreshPeriod   time.Duration | ||||||
|  | 			refreshSession  func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||||
|  | 			validateSession func(context.Context, *sessionsapi.SessionState) bool | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("when serving a request", | ||||||
|  | 			func(in storedSessionLoaderTableInput) { | ||||||
|  | 				scope := &middlewareapi.RequestScope{ | ||||||
|  | 					Session: in.existingSession, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// Set up the request with the request headesr and a request scope
 | ||||||
|  | 				req := httptest.NewRequest("", "/", nil) | ||||||
|  | 				req.Header = in.requestHeaders | ||||||
|  | 				contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||||
|  | 				req = req.WithContext(contextWithScope) | ||||||
|  | 
 | ||||||
|  | 				rw := httptest.NewRecorder() | ||||||
|  | 
 | ||||||
|  | 				opts := &StoredSessionLoaderOptions{ | ||||||
|  | 					SessionStore:           in.store, | ||||||
|  | 					RefreshPeriod:          in.refreshPeriod, | ||||||
|  | 					RefreshSessionIfNeeded: in.refreshSession, | ||||||
|  | 					ValidateSessionState:   in.validateSession, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// Create the handler with a next handler that will capture the session
 | ||||||
|  | 				// from the scope
 | ||||||
|  | 				var gotSession *sessionsapi.SessionState | ||||||
|  | 				handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 					gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session | ||||||
|  | 				})) | ||||||
|  | 				handler.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 				Expect(gotSession).To(Equal(in.expectedSession)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with no cookie", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders:  http.Header{}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: nil, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid cookie", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=NonExistent"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: nil, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an existing session", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: "Existing", | ||||||
|  | 				}, | ||||||
|  | 				expectedSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: "Existing", | ||||||
|  | 				}, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a session that has not expired", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: noRefresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					ExpiresOn:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: nil, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					ExpiresOn:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   10 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: "Refreshed", | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					ExpiresOn:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the provider refresh fails", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=RefreshError"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: nil, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: nil, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("refreshSessionIfNeeded", func() { | ||||||
|  | 		type refreshSessionIfNeededTableInput struct { | ||||||
|  | 			refreshPeriod   time.Duration | ||||||
|  | 			session         *sessionsapi.SessionState | ||||||
|  | 			expectedErr     error | ||||||
|  | 			expectRefreshed bool | ||||||
|  | 			expectValidated bool | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		createdPast := time.Now().Add(-5 * time.Minute) | ||||||
|  | 		createdFuture := time.Now().Add(5 * time.Minute) | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("with a session", | ||||||
|  | 			func(in refreshSessionIfNeededTableInput) { | ||||||
|  | 				refreshed := false | ||||||
|  | 				validated := false | ||||||
|  | 
 | ||||||
|  | 				s := &storedSessionLoader{ | ||||||
|  | 					refreshPeriod: in.refreshPeriod, | ||||||
|  | 					store:         &fakeSessionStore{}, | ||||||
|  | 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
|  | 						refreshed = true | ||||||
|  | 						switch ss.RefreshToken { | ||||||
|  | 						case refresh: | ||||||
|  | 							return true, nil | ||||||
|  | 						case noRefresh: | ||||||
|  | 							return false, nil | ||||||
|  | 						default: | ||||||
|  | 							return false, errors.New("error refreshing session") | ||||||
|  | 						} | ||||||
|  | 					}, | ||||||
|  | 					validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||||
|  | 						validated = true | ||||||
|  | 						return ss.AccessToken != "Invalid" | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				req := httptest.NewRequest("", "/", nil) | ||||||
|  | 				err := s.refreshSessionIfNeeded(nil, req, in.session) | ||||||
|  | 				if in.expectedErr != nil { | ||||||
|  | 					Expect(err).To(MatchError(in.expectedErr)) | ||||||
|  | 				} else { | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				} | ||||||
|  | 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||||
|  | 				Expect(validated).To(Equal(in.expectValidated)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: time.Duration(0), | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectValidated: false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: time.Duration(0), | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectValidated: false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectValidated: false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: true, | ||||||
|  | 				expectValidated: false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: noRefresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					ExpiresOn:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: true, | ||||||
|  | 				expectValidated: true, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: "RefreshError", | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | ||||||
|  | 				expectRefreshed: true, | ||||||
|  | 				expectValidated: false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					AccessToken:  "Invalid", | ||||||
|  | 					RefreshToken: noRefresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 					ExpiresOn:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     errors.New("session is invalid"), | ||||||
|  | 				expectRefreshed: true, | ||||||
|  | 				expectValidated: true, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("refreshSessionWithProvider", func() { | ||||||
|  | 		type refreshSessionWithProviderTableInput struct { | ||||||
|  | 			session         *sessionsapi.SessionState | ||||||
|  | 			expectedErr     error | ||||||
|  | 			expectRefreshed bool | ||||||
|  | 			expectSaved     bool | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		now := time.Now() | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("when refreshing with the provider", | ||||||
|  | 			func(in refreshSessionWithProviderTableInput) { | ||||||
|  | 				saved := false | ||||||
|  | 
 | ||||||
|  | 				s := &storedSessionLoader{ | ||||||
|  | 					store: &fakeSessionStore{ | ||||||
|  | 						SaveFunc: func(_ http.ResponseWriter, _ *http.Request, ss *sessionsapi.SessionState) error { | ||||||
|  | 							saved = true | ||||||
|  | 							if ss.AccessToken == "NoSave" { | ||||||
|  | 								return errors.New("unable to save session") | ||||||
|  | 							} | ||||||
|  | 							return nil | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
|  | 						switch ss.RefreshToken { | ||||||
|  | 						case refresh: | ||||||
|  | 							return true, nil | ||||||
|  | 						case noRefresh: | ||||||
|  | 							return false, nil | ||||||
|  | 						default: | ||||||
|  | 							return false, errors.New("error refreshing session") | ||||||
|  | 						} | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				req := httptest.NewRequest("", "/", nil) | ||||||
|  | 				refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) | ||||||
|  | 				if in.expectedErr != nil { | ||||||
|  | 					Expect(err).To(MatchError(in.expectedErr)) | ||||||
|  | 				} else { | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				} | ||||||
|  | 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||||
|  | 				Expect(saved).To(Equal(in.expectSaved)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: noRefresh, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectSaved:     false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: true, | ||||||
|  | 				expectSaved:     true, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: "RefreshError", | ||||||
|  | 					CreatedAt:    &now, | ||||||
|  | 					ExpiresOn:    &now, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectSaved:     false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					AccessToken:  "NoSave", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     errors.New("error saving session: unable to save session"), | ||||||
|  | 				expectRefreshed: false, | ||||||
|  | 				expectSaved:     true, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("validateSession", func() { | ||||||
|  | 		var s *storedSessionLoader | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			s = &storedSessionLoader{ | ||||||
|  | 				validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||||
|  | 					return ss.AccessToken == "Valid" | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with a valid session", func() { | ||||||
|  | 			It("does not return an error", func() { | ||||||
|  | 				expires := time.Now().Add(1 * time.Minute) | ||||||
|  | 				session := &sessionsapi.SessionState{ | ||||||
|  | 					AccessToken: "Valid", | ||||||
|  | 					ExpiresOn:   &expires, | ||||||
|  | 				} | ||||||
|  | 				Expect(s.validateSession(ctx, session)).To(Succeed()) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with an expired session", func() { | ||||||
|  | 			It("returns an error", func() { | ||||||
|  | 				created := time.Now().Add(-5 * time.Minute) | ||||||
|  | 				expires := time.Now().Add(-1 * time.Minute) | ||||||
|  | 				session := &sessionsapi.SessionState{ | ||||||
|  | 					AccessToken: "Valid", | ||||||
|  | 					CreatedAt:   &created, | ||||||
|  | 					ExpiresOn:   &expires, | ||||||
|  | 				} | ||||||
|  | 				Expect(s.validateSession(ctx, session)).To(MatchError("session is expired")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with an invalid session", func() { | ||||||
|  | 			It("returns an error", func() { | ||||||
|  | 				expires := time.Now().Add(1 * time.Minute) | ||||||
|  | 				session := &sessionsapi.SessionState{ | ||||||
|  | 					AccessToken: "Invalid", | ||||||
|  | 					ExpiresOn:   &expires, | ||||||
|  | 				} | ||||||
|  | 				Expect(s.validateSession(ctx, session)).To(MatchError("session is invalid")) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | 
 | ||||||
|  | type fakeSessionStore struct { | ||||||
|  | 	SaveFunc  func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error | ||||||
|  | 	LoadFunc  func(req *http.Request) (*sessionsapi.SessionState, error) | ||||||
|  | 	ClearFunc func(rw http.ResponseWriter, req *http.Request) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (f *fakeSessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | ||||||
|  | 	if f.SaveFunc != nil { | ||||||
|  | 		return f.SaveFunc(rw, req, s) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 	if f.LoadFunc != nil { | ||||||
|  | 		return f.LoadFunc(req) | ||||||
|  | 	} | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||||
|  | 	if f.ClearFunc != nil { | ||||||
|  | 		return f.ClearFunc(rw, req) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue