Add concurrent requests tests
This commit is contained in:
		
							parent
							
								
									c5ea345daf
								
							
						
					
					
						commit
						ad8ce2f6a4
					
				|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | @ -65,9 +66,50 @@ func (l *TestLock) Release(_ context.Context) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type LockConc struct { | ||||||
|  | 	mu          sync.Mutex | ||||||
|  | 	lock        bool | ||||||
|  | 	disablePeek bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Obtain(_ context.Context, _ time.Duration) error { | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	if l.lock { | ||||||
|  | 		l.mu.Unlock() | ||||||
|  | 		return sessionsapi.ErrLockNotObtained | ||||||
|  | 	} | ||||||
|  | 	l.lock = true | ||||||
|  | 	l.mu.Unlock() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Peek(_ context.Context) (bool, error) { | ||||||
|  | 	var response bool | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	if l.disablePeek { | ||||||
|  | 		response = false | ||||||
|  | 	} else { | ||||||
|  | 		response = l.lock | ||||||
|  | 	} | ||||||
|  | 	l.mu.Unlock() | ||||||
|  | 	return response, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Refresh(_ context.Context, _ time.Duration) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LockConc) Release(_ context.Context) error { | ||||||
|  | 	l.mu.Lock() | ||||||
|  | 	l.lock = false | ||||||
|  | 	l.mu.Unlock() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| var _ = Describe("Stored Session Suite", func() { | var _ = Describe("Stored Session Suite", func() { | ||||||
| 	const ( | 	const ( | ||||||
| 		refresh        = "Refresh" | 		refresh        = "Refresh" | ||||||
|  | 		refreshed      = "Refreshed" | ||||||
| 		noRefresh      = "NoRefresh" | 		noRefresh      = "NoRefresh" | ||||||
| 		notImplemented = "NotImplemented" | 		notImplemented = "NotImplemented" | ||||||
| 	) | 	) | ||||||
|  | @ -82,7 +124,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 			switch ss.RefreshToken { | 			switch ss.RefreshToken { | ||||||
| 			case refresh: | 			case refresh: | ||||||
| 				ss.RefreshToken = "Refreshed" | 				ss.RefreshToken = refreshed | ||||||
| 				return true, nil | 				return true, nil | ||||||
| 			case noRefresh: | 			case noRefresh: | ||||||
| 				return false, nil | 				return false, nil | ||||||
|  | @ -317,6 +359,98 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				validateSession: defaultValidateFunc, | 				validateSession: defaultValidateFunc, | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
|  | 
 | ||||||
|  | 		type storedSessionLoaderConcurrentTableInput struct { | ||||||
|  | 			existingSession *sessionsapi.SessionState | ||||||
|  | 			refreshPeriod   time.Duration | ||||||
|  | 			numConcReqs     int | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("when serving concurrent requests", | ||||||
|  | 			func(in storedSessionLoaderConcurrentTableInput) { | ||||||
|  | 				lockConc := &LockConc{} | ||||||
|  | 
 | ||||||
|  | 				refreshedChan := make(chan bool, in.numConcReqs) | ||||||
|  | 				for i := 0; i < in.numConcReqs; i++ { | ||||||
|  | 					go func(refreshedChan chan bool, lockConc sessionsapi.Lock) { | ||||||
|  | 						existingSession := *in.existingSession // deep copy existingSession state
 | ||||||
|  | 						existingSession.Lock = lockConc | ||||||
|  | 						store := &fakeSessionStore{ | ||||||
|  | 							LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
|  | 								return &existingSession, nil | ||||||
|  | 							}, | ||||||
|  | 							SaveFunc: func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error { | ||||||
|  | 								return nil | ||||||
|  | 							}, | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						scope := &middlewareapi.RequestScope{ | ||||||
|  | 							Session: nil, | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						// Set up the request with the request header and a request scope
 | ||||||
|  | 						req := httptest.NewRequest("", "/", nil) | ||||||
|  | 						req = middlewareapi.AddRequestScope(req, scope) | ||||||
|  | 
 | ||||||
|  | 						rw := httptest.NewRecorder() | ||||||
|  | 
 | ||||||
|  | 						sessionRefreshed := false | ||||||
|  | 						opts := &StoredSessionLoaderOptions{ | ||||||
|  | 							SessionStore:  store, | ||||||
|  | 							RefreshPeriod: in.refreshPeriod, | ||||||
|  | 							RefreshSession: func(ctx context.Context, s *sessionsapi.SessionState) (bool, error) { | ||||||
|  | 								time.Sleep(10 * time.Millisecond) | ||||||
|  | 								sessionRefreshed = true | ||||||
|  | 								return true, nil | ||||||
|  | 							}, | ||||||
|  | 							ValidateSession: func(context.Context, *sessionsapi.SessionState) bool { | ||||||
|  | 								return true | ||||||
|  | 							}, | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
|  | 						handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) | ||||||
|  | 						handler.ServeHTTP(rw, req) | ||||||
|  | 
 | ||||||
|  | 						refreshedChan <- sessionRefreshed | ||||||
|  | 					}(refreshedChan, lockConc) | ||||||
|  | 				} | ||||||
|  | 				var refreshedSlice []bool | ||||||
|  | 				for i := 0; i < in.numConcReqs; i++ { | ||||||
|  | 					refreshedSlice = append(refreshedSlice, <-refreshedChan) | ||||||
|  | 				} | ||||||
|  | 				sessionRefreshedCount := 0 | ||||||
|  | 				for _, sessionRefreshed := range refreshedSlice { | ||||||
|  | 					if sessionRefreshed { | ||||||
|  | 						sessionRefreshedCount++ | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				Expect(sessionRefreshedCount).To(Equal(1)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with two concurrent requests", storedSessionLoaderConcurrentTableInput{ | ||||||
|  | 				existingSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				numConcReqs:   2, | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with 5 concurrent requests", storedSessionLoaderConcurrentTableInput{ | ||||||
|  | 				existingSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				numConcReqs:   5, | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with one request", storedSessionLoaderConcurrentTableInput{ | ||||||
|  | 				existingSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: refresh, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				numConcReqs:   1, | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	Context("refreshSessionIfNeeded", func() { | 	Context("refreshSessionIfNeeded", func() { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue