Add session loader from session storage
This commit is contained in:
		
							parent
							
								
									7d6f2a3f45
								
							
						
					
					
						commit
						034f057b60
					
				|  | @ -0,0 +1,165 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
| // StoredSessionLoaderOptions cotnains all of the requirements to construct
 | ||||
| // a stored session loader.
 | ||||
| // All options must be provided.
 | ||||
| type StoredSessionLoaderOptions struct { | ||||
| 	// Session storage basckend
 | ||||
| 	SessionStore sessionsapi.SessionStore | ||||
| 
 | ||||
| 	// How often should sessions be refreshed
 | ||||
| 	RefreshPeriod time.Duration | ||||
| 
 | ||||
| 	// Provider based sesssion refreshing
 | ||||
| 	RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||
| 
 | ||||
| 	// Provider based session validation.
 | ||||
| 	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
 | ||||
| 	// refresh it, we must re-validate using this validation.
 | ||||
| 	ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool | ||||
| } | ||||
| 
 | ||||
| // NewStoredSessionLoader creates a new storedSessionLoader which loads
 | ||||
| // sessions from the session store.
 | ||||
| // If no session is found, the request will be passed to the nex handler.
 | ||||
| // If a session was loader by a previous handler, it will not be replaced.
 | ||||
| func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { | ||||
| 	ss := &storedSessionLoader{ | ||||
| 		store:                              opts.SessionStore, | ||||
| 		refreshPeriod:                      opts.RefreshPeriod, | ||||
| 		refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, | ||||
| 		validateSessionState:               opts.ValidateSessionState, | ||||
| 	} | ||||
| 	return ss.loadSession | ||||
| } | ||||
| 
 | ||||
| // storedSessionLoader is responsible for loading sessions from cookie
 | ||||
| // identified sessions in the session store.
 | ||||
| type storedSessionLoader struct { | ||||
| 	store                              sessionsapi.SessionStore | ||||
| 	refreshPeriod                      time.Duration | ||||
| 	refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||
| 	validateSessionState               func(context.Context, *sessionsapi.SessionState) bool | ||||
| } | ||||
| 
 | ||||
| // loadSession attempts to load a session as identified by the request cookies.
 | ||||
| // If no session is found, the request will be passed to the nex handler.
 | ||||
| // If a session was loader by a previous handler, it will not be replaced.
 | ||||
| func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := GetRequestScope(req) | ||||
| 		// If scope is nil, this will panic.
 | ||||
| 		// A scope should always be injected before this handler is called.
 | ||||
| 		if scope.Session != nil { | ||||
| 			// The session was already loaded, pass to the next handler
 | ||||
| 			next.ServeHTTP(rw, req) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		session, err := s.getValidatedSession(rw, req) | ||||
| 		if err != nil { | ||||
| 			// In the case when there was an error loading the session,
 | ||||
| 			// we should clear the session
 | ||||
| 			logger.Printf("Error loading cookied session: %v, removing session", err) | ||||
| 			s.store.Clear(rw, req) | ||||
| 		} | ||||
| 
 | ||||
| 		// Add the session to the scope if it was found
 | ||||
| 		scope.Session = session | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // getValidatedSession is responsible for loading a session and making sure
 | ||||
| // that is is valid.
 | ||||
| func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	session, err := s.store.Load(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if session == nil { | ||||
| 		// No session was found in the storage, nothing more to do
 | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	err = s.refreshSessionIfNeeded(rw, req, session) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) | ||||
| 	} | ||||
| 
 | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| // refreshSessionIfNeeded will attempt to refresh a session if the session
 | ||||
| // is older than the refresh period.
 | ||||
| // It is assumed that if the provider refreshes the session, the session is now
 | ||||
| // valid.
 | ||||
| // If the session requires refreshing but the provider does not refresh it,
 | ||||
| // we must validate the session to ensure that the returned session is still
 | ||||
| // valid.
 | ||||
| func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||
| 	if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { | ||||
| 		// Refresh is disabled or the session is not old enough, do nothing
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) | ||||
| 	refreshed, err := s.refreshSessionWithProvider(rw, req, session) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if !refreshed { | ||||
| 		// Session wasn't refreshed, so make sure it's still valid
 | ||||
| 		return s.validateSession(req.Context(), session) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // refreshSessionWithProvider attempts to refresh the sessinon with the provider
 | ||||
| // and will save the session if it was updated.
 | ||||
| func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { | ||||
| 	refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) | ||||
| 	if err != nil { | ||||
| 		return false, fmt.Errorf("error refreshing access token: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if !refreshed { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Because the session was refreshed, make sure to save it
 | ||||
| 	err = s.store.Save(rw, req, session) | ||||
| 	if err != nil { | ||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||
| 		return false, fmt.Errorf("error saving session: %v", err) | ||||
| 	} | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| // validateSession checks whether the session has expired and performs
 | ||||
| // provider validation on the session.
 | ||||
| // An error implies the session is not longer valid.
 | ||||
| func (s *storedSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error { | ||||
| 	if session.IsExpired() { | ||||
| 		return errors.New("session is expired") | ||||
| 	} | ||||
| 
 | ||||
| 	if !s.validateSessionState(ctx, session) { | ||||
| 		return errors.New("session is invalid") | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
|  | @ -0,0 +1,524 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"time" | ||||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Stored Session Suite", func() { | ||||
| 	const ( | ||||
| 		refresh   = "Refresh" | ||||
| 		noRefresh = "NoRefresh" | ||||
| 	) | ||||
| 
 | ||||
| 	var ctx = context.Background() | ||||
| 
 | ||||
| 	Context("StoredSessionLoader", func() { | ||||
| 		createdPast := time.Now().Add(-5 * time.Minute) | ||||
| 		createdFuture := time.Now().Add(5 * time.Minute) | ||||
| 
 | ||||
| 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||
| 			switch ss.RefreshToken { | ||||
| 			case refresh: | ||||
| 				ss.RefreshToken = "Refreshed" | ||||
| 				return true, nil | ||||
| 			case noRefresh: | ||||
| 				return false, nil | ||||
| 			default: | ||||
| 				return false, errors.New("error refreshing session") | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||
| 			return ss.AccessToken != "Invalid" | ||||
| 		} | ||||
| 
 | ||||
| 		var defaultSessionStore = &fakeSessionStore{ | ||||
| 			LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 				switch req.Header.Get("Cookie") { | ||||
| 				case "_oauth2_proxy=NoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=InvalidNoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						AccessToken:  "Invalid", | ||||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=ExpiredNoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdPast, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=RefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: refresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=RefreshError": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: "RefreshError", | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=NonExistent": | ||||
| 					return nil, fmt.Errorf("invalid cookie") | ||||
| 				default: | ||||
| 					return nil, nil | ||||
| 				} | ||||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		type storedSessionLoaderTableInput struct { | ||||
| 			requestHeaders  http.Header | ||||
| 			existingSession *sessionsapi.SessionState | ||||
| 			expectedSession *sessionsapi.SessionState | ||||
| 			store           sessionsapi.SessionStore | ||||
| 			refreshPeriod   time.Duration | ||||
| 			refreshSession  func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||
| 			validateSession func(context.Context, *sessionsapi.SessionState) bool | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("when serving a request", | ||||
| 			func(in storedSessionLoaderTableInput) { | ||||
| 				scope := &middlewareapi.RequestScope{ | ||||
| 					Session: in.existingSession, | ||||
| 				} | ||||
| 
 | ||||
| 				// Set up the request with the request headesr and a request scope
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				req.Header = in.requestHeaders | ||||
| 				contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||
| 				req = req.WithContext(contextWithScope) | ||||
| 
 | ||||
| 				rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 				opts := &StoredSessionLoaderOptions{ | ||||
| 					SessionStore:           in.store, | ||||
| 					RefreshPeriod:          in.refreshPeriod, | ||||
| 					RefreshSessionIfNeeded: in.refreshSession, | ||||
| 					ValidateSessionState:   in.validateSession, | ||||
| 				} | ||||
| 
 | ||||
| 				// Create the handler with a next handler that will capture the session
 | ||||
| 				// from the scope
 | ||||
| 				var gotSession *sessionsapi.SessionState | ||||
| 				handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 					gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session | ||||
| 				})) | ||||
| 				handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 				Expect(gotSession).To(Equal(in.expectedSession)) | ||||
| 			}, | ||||
| 			Entry("with no cookie", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders:  http.Header{}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with an invalid cookie", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=NonExistent"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with an existing session", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "Existing", | ||||
| 				}, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "Existing", | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that has not expired", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: noRefresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   10 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "Refreshed", | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("when the provider refresh fails", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshError"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("refreshSessionIfNeeded", func() { | ||||
| 		type refreshSessionIfNeededTableInput struct { | ||||
| 			refreshPeriod   time.Duration | ||||
| 			session         *sessionsapi.SessionState | ||||
| 			expectedErr     error | ||||
| 			expectRefreshed bool | ||||
| 			expectValidated bool | ||||
| 		} | ||||
| 
 | ||||
| 		createdPast := time.Now().Add(-5 * time.Minute) | ||||
| 		createdFuture := time.Now().Add(5 * time.Minute) | ||||
| 
 | ||||
| 		DescribeTable("with a session", | ||||
| 			func(in refreshSessionIfNeededTableInput) { | ||||
| 				refreshed := false | ||||
| 				validated := false | ||||
| 
 | ||||
| 				s := &storedSessionLoader{ | ||||
| 					refreshPeriod: in.refreshPeriod, | ||||
| 					store:         &fakeSessionStore{}, | ||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||
| 						refreshed = true | ||||
| 						switch ss.RefreshToken { | ||||
| 						case refresh: | ||||
| 							return true, nil | ||||
| 						case noRefresh: | ||||
| 							return false, nil | ||||
| 						default: | ||||
| 							return false, errors.New("error refreshing session") | ||||
| 						} | ||||
| 					}, | ||||
| 					validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||
| 						validated = true | ||||
| 						return ss.AccessToken != "Invalid" | ||||
| 					}, | ||||
| 				} | ||||
| 
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				err := s.refreshSessionIfNeeded(nil, req, in.session) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||
| 				Expect(validated).To(Equal(in.expectValidated)) | ||||
| 			}, | ||||
| 			Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: time.Duration(0), | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: time.Duration(0), | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: noRefresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: true, | ||||
| 			}), | ||||
| 			Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "RefreshError", | ||||
| 					CreatedAt:    &createdPast, | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					AccessToken:  "Invalid", | ||||
| 					RefreshToken: noRefresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("session is invalid"), | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: true, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("refreshSessionWithProvider", func() { | ||||
| 		type refreshSessionWithProviderTableInput struct { | ||||
| 			session         *sessionsapi.SessionState | ||||
| 			expectedErr     error | ||||
| 			expectRefreshed bool | ||||
| 			expectSaved     bool | ||||
| 		} | ||||
| 
 | ||||
| 		now := time.Now() | ||||
| 
 | ||||
| 		DescribeTable("when refreshing with the provider", | ||||
| 			func(in refreshSessionWithProviderTableInput) { | ||||
| 				saved := false | ||||
| 
 | ||||
| 				s := &storedSessionLoader{ | ||||
| 					store: &fakeSessionStore{ | ||||
| 						SaveFunc: func(_ http.ResponseWriter, _ *http.Request, ss *sessionsapi.SessionState) error { | ||||
| 							saved = true | ||||
| 							if ss.AccessToken == "NoSave" { | ||||
| 								return errors.New("unable to save session") | ||||
| 							} | ||||
| 							return nil | ||||
| 						}, | ||||
| 					}, | ||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||
| 						switch ss.RefreshToken { | ||||
| 						case refresh: | ||||
| 							return true, nil | ||||
| 						case noRefresh: | ||||
| 							return false, nil | ||||
| 						default: | ||||
| 							return false, errors.New("error refreshing session") | ||||
| 						} | ||||
| 					}, | ||||
| 				} | ||||
| 
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||
| 				Expect(saved).To(Equal(in.expectSaved)) | ||||
| 			}, | ||||
| 			Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: noRefresh, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectSaved:     false, | ||||
| 			}), | ||||
| 			Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: true, | ||||
| 				expectSaved:     true, | ||||
| 			}), | ||||
| 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "RefreshError", | ||||
| 					CreatedAt:    &now, | ||||
| 					ExpiresOn:    &now, | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | ||||
| 				expectRefreshed: false, | ||||
| 				expectSaved:     false, | ||||
| 			}), | ||||
| 			Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					AccessToken:  "NoSave", | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("error saving session: unable to save session"), | ||||
| 				expectRefreshed: false, | ||||
| 				expectSaved:     true, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("validateSession", func() { | ||||
| 		var s *storedSessionLoader | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			s = &storedSessionLoader{ | ||||
| 				validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||
| 					return ss.AccessToken == "Valid" | ||||
| 				}, | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with a valid session", func() { | ||||
| 			It("does not return an error", func() { | ||||
| 				expires := time.Now().Add(1 * time.Minute) | ||||
| 				session := &sessionsapi.SessionState{ | ||||
| 					AccessToken: "Valid", | ||||
| 					ExpiresOn:   &expires, | ||||
| 				} | ||||
| 				Expect(s.validateSession(ctx, session)).To(Succeed()) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with an expired session", func() { | ||||
| 			It("returns an error", func() { | ||||
| 				created := time.Now().Add(-5 * time.Minute) | ||||
| 				expires := time.Now().Add(-1 * time.Minute) | ||||
| 				session := &sessionsapi.SessionState{ | ||||
| 					AccessToken: "Valid", | ||||
| 					CreatedAt:   &created, | ||||
| 					ExpiresOn:   &expires, | ||||
| 				} | ||||
| 				Expect(s.validateSession(ctx, session)).To(MatchError("session is expired")) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with an invalid session", func() { | ||||
| 			It("returns an error", func() { | ||||
| 				expires := time.Now().Add(1 * time.Minute) | ||||
| 				session := &sessionsapi.SessionState{ | ||||
| 					AccessToken: "Invalid", | ||||
| 					ExpiresOn:   &expires, | ||||
| 				} | ||||
| 				Expect(s.validateSession(ctx, session)).To(MatchError("session is invalid")) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
| 
 | ||||
| type fakeSessionStore struct { | ||||
| 	SaveFunc  func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error | ||||
| 	LoadFunc  func(req *http.Request) (*sessionsapi.SessionState, error) | ||||
| 	ClearFunc func(rw http.ResponseWriter, req *http.Request) error | ||||
| } | ||||
| 
 | ||||
| func (f *fakeSessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | ||||
| 	if f.SaveFunc != nil { | ||||
| 		return f.SaveFunc(rw, req, s) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	if f.LoadFunc != nil { | ||||
| 		return f.LoadFunc(req) | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| 	if f.ClearFunc != nil { | ||||
| 		return f.ClearFunc(rw, req) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
		Loading…
	
		Reference in New Issue