Replace pkg/clock with narrowly targeted stub clocks.
The package under pkg/clock is github.com/benbjohnson/clock, which is archived. It's also way more complex than is what is actually needed here, so we can entirely remove the dependency and remove the helper package. Fixes #2840. Signed-off-by: David Symonds <dsymonds@gmail.com>
This commit is contained in:
		
							parent
							
								
									9168731c7a
								
							
						
					
					
						commit
						27f8b2970f
					
				
							
								
								
									
										1
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										1
									
								
								go.mod
								
								
								
								
							|  | @ -7,7 +7,6 @@ require ( | |||
| 	github.com/Bose/minisentinel v0.0.0-20200130220412-917c5a9223bb | ||||
| 	github.com/a8m/envsubst v1.4.3 | ||||
| 	github.com/alicebob/miniredis/v2 v2.35.0 | ||||
| 	github.com/benbjohnson/clock v1.3.5 | ||||
| 	github.com/bitly/go-simplejson v0.5.1 | ||||
| 	github.com/bsm/redislock v0.9.4 | ||||
| 	github.com/coreos/go-oidc/v3 v3.14.1 | ||||
|  |  | |||
							
								
								
									
										2
									
								
								go.sum
								
								
								
								
							
							
						
						
									
										2
									
								
								go.sum
								
								
								
								
							|  | @ -14,8 +14,6 @@ github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGn | |||
| github.com/alicebob/miniredis/v2 v2.11.1/go.mod h1:UA48pmi7aSazcGAvcdKcBB49z521IC9VjTTRz2nIaJE= | ||||
| github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= | ||||
| github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= | ||||
| github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= | ||||
| github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= | ||||
| github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= | ||||
| github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= | ||||
| github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow= | ||||
|  |  | |||
|  | @ -7,7 +7,6 @@ import ( | |||
| 	"io" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||
| 	"github.com/pierrec/lz4/v4" | ||||
| 	"github.com/vmihailenco/msgpack/v5" | ||||
|  | @ -30,10 +29,17 @@ type SessionState struct { | |||
| 	PreferredUsername string   `msgpack:"pu,omitempty"` | ||||
| 
 | ||||
| 	// Internal helpers, not serialized
 | ||||
| 	Clock clock.Clock `msgpack:"-"` | ||||
| 	Clock func() time.Time `msgpack:"-"` // override for time.Now, for testing
 | ||||
| 	Lock  Lock             `msgpack:"-"` | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) now() time.Time { | ||||
| 	if s.Clock != nil { | ||||
| 		return s.Clock() | ||||
| 	} | ||||
| 	return time.Now() | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { | ||||
| 	if s.Lock == nil { | ||||
| 		s.Lock = &NoOpLock{} | ||||
|  | @ -64,7 +70,7 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { | |||
| 
 | ||||
| // CreatedAtNow sets a SessionState's CreatedAt to now
 | ||||
| func (s *SessionState) CreatedAtNow() { | ||||
| 	now := s.Clock.Now() | ||||
| 	now := s.now() | ||||
| 	s.CreatedAt = &now | ||||
| } | ||||
| 
 | ||||
|  | @ -85,7 +91,7 @@ func (s *SessionState) ExpiresIn(d time.Duration) { | |||
| 
 | ||||
| // IsExpired checks whether the session has expired
 | ||||
| func (s *SessionState) IsExpired() bool { | ||||
| 	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { | ||||
| 	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.now()) { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
|  | @ -94,7 +100,7 @@ func (s *SessionState) IsExpired() bool { | |||
| // Age returns the age of a session
 | ||||
| func (s *SessionState) Age() time.Duration { | ||||
| 	if s.CreatedAt != nil && !s.CreatedAt.IsZero() { | ||||
| 		return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) | ||||
| 		return s.now().Truncate(time.Second).Sub(*s.CreatedAt) | ||||
| 	} | ||||
| 	return 0 | ||||
| } | ||||
|  |  | |||
|  | @ -22,7 +22,7 @@ func TestCreatedAtNow(t *testing.T) { | |||
| 	ss := &SessionState{} | ||||
| 
 | ||||
| 	now := time.Unix(1234567890, 0) | ||||
| 	ss.Clock.Set(now) | ||||
| 	ss.Clock = func() time.Time { return now } | ||||
| 
 | ||||
| 	ss.CreatedAtNow() | ||||
| 	g.Expect(*ss.CreatedAt).To(Equal(now)) | ||||
|  | @ -33,9 +33,9 @@ func TestExpiresIn(t *testing.T) { | |||
| 	ss := &SessionState{} | ||||
| 
 | ||||
| 	now := time.Unix(1234567890, 0) | ||||
| 	ss.Clock.Set(now) | ||||
| 	ss.Clock = func() time.Time { return now } | ||||
| 
 | ||||
| 	ttl := time.Duration(743) * time.Second | ||||
| 	ttl := 743 * time.Second | ||||
| 	ss.ExpiresIn(ttl) | ||||
| 
 | ||||
| 	g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl))) | ||||
|  |  | |||
|  | @ -1,157 +0,0 @@ | |||
| package clock | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	clockapi "github.com/benbjohnson/clock" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	globalClock = clockapi.New() | ||||
| 	mu          sync.Mutex | ||||
| ) | ||||
| 
 | ||||
| // Set the global clock to a clockapi.Mock with the given time.Time
 | ||||
| func Set(t time.Time) { | ||||
| 	mu.Lock() | ||||
| 	defer mu.Unlock() | ||||
| 	mock, ok := globalClock.(*clockapi.Mock) | ||||
| 	if !ok { | ||||
| 		mock = clockapi.NewMock() | ||||
| 	} | ||||
| 	mock.Set(t) | ||||
| 	globalClock = mock | ||||
| } | ||||
| 
 | ||||
| // Add moves the mocked global clock forward the given duration. It will error
 | ||||
| // if the global clock is not mocked.
 | ||||
| func Add(d time.Duration) error { | ||||
| 	mu.Lock() | ||||
| 	defer mu.Unlock() | ||||
| 	mock, ok := globalClock.(*clockapi.Mock) | ||||
| 	if !ok { | ||||
| 		return errors.New("time not mocked") | ||||
| 	} | ||||
| 	mock.Add(d) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Reset sets the global clock to a pure time implementation. Returns any
 | ||||
| // existing Mock if set in case lingering time operations are attached to it.
 | ||||
| func Reset() *clockapi.Mock { | ||||
| 	mu.Lock() | ||||
| 	defer mu.Unlock() | ||||
| 	existing := globalClock | ||||
| 	globalClock = clockapi.New() | ||||
| 
 | ||||
| 	mock, ok := existing.(*clockapi.Mock) | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return mock | ||||
| } | ||||
| 
 | ||||
| // Clock is a non-package level wrapper around time that supports stubbing.
 | ||||
| // It will use its localized stubs (allowing for parallelized unit tests
 | ||||
| // where package level stubbing would cause issues). It falls back to any
 | ||||
| // package level time stubs for non-parallel, cross-package integration
 | ||||
| // testing scenarios.
 | ||||
| //
 | ||||
| // If nothing is stubbed, it defaults to default time behavior in the time
 | ||||
| // package.
 | ||||
| type Clock struct { | ||||
| 	mock *clockapi.Mock | ||||
| } | ||||
| 
 | ||||
| // Set sets the Clock to a clock.Mock at the given time.Time
 | ||||
| func (c *Clock) Set(t time.Time) { | ||||
| 	if c.mock == nil { | ||||
| 		c.mock = clockapi.NewMock() | ||||
| 	} | ||||
| 	c.mock.Set(t) | ||||
| } | ||||
| 
 | ||||
| // Add moves clock forward time.Duration if it is mocked. It will error
 | ||||
| // if the clock is not mocked.
 | ||||
| func (c *Clock) Add(d time.Duration) error { | ||||
| 	if c.mock == nil { | ||||
| 		return errors.New("clock not mocked") | ||||
| 	} | ||||
| 	c.mock.Add(d) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Reset removes local clock.Mock.  Returns any existing Mock if set in case
 | ||||
| // lingering time operations are attached to it.
 | ||||
| func (c *Clock) Reset() *clockapi.Mock { | ||||
| 	existing := c.mock | ||||
| 	c.mock = nil | ||||
| 	return existing | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) After(d time.Duration) <-chan time.Time { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		return globalClock.After(d) | ||||
| 	} | ||||
| 	return m.After(d) | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) AfterFunc(d time.Duration, f func()) *clockapi.Timer { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		return globalClock.AfterFunc(d, f) | ||||
| 	} | ||||
| 	return m.AfterFunc(d, f) | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) Now() time.Time { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		return globalClock.Now() | ||||
| 	} | ||||
| 	return m.Now() | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) Since(t time.Time) time.Duration { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		return globalClock.Since(t) | ||||
| 	} | ||||
| 	return m.Since(t) | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) Sleep(d time.Duration) { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		globalClock.Sleep(d) | ||||
| 		return | ||||
| 	} | ||||
| 	m.Sleep(d) | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) Tick(d time.Duration) <-chan time.Time { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		return globalClock.Tick(d) | ||||
| 	} | ||||
| 	return m.Tick(d) | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) Ticker(d time.Duration) *clockapi.Ticker { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		return globalClock.Ticker(d) | ||||
| 	} | ||||
| 	return m.Ticker(d) | ||||
| } | ||||
| 
 | ||||
| func (c *Clock) Timer(d time.Duration) *clockapi.Timer { | ||||
| 	m := c.mock | ||||
| 	if m == nil { | ||||
| 		return globalClock.Timer(d) | ||||
| 	} | ||||
| 	return m.Timer(d) | ||||
| } | ||||
|  | @ -1,17 +0,0 @@ | |||
| package clock_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	. "github.com/onsi/ginkgo/v2" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| func TestClockSuite(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 	logger.SetErrOutput(GinkgoWriter) | ||||
| 
 | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "Clock") | ||||
| } | ||||
|  | @ -1,380 +0,0 @@ | |||
| package clock_test | ||||
| 
 | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||
| 	. "github.com/onsi/ginkgo/v2" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	testGlobalEpoch = 1000000000 | ||||
| 	testLocalEpoch  = 1234567890 | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Clock suite", func() { | ||||
| 	var testClock = clock.Clock{} | ||||
| 
 | ||||
| 	AfterEach(func() { | ||||
| 		clock.Reset() | ||||
| 		testClock.Reset() | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("Global time not overridden", func() { | ||||
| 		It("errors when trying to Add", func() { | ||||
| 			err := clock.Add(123 * time.Hour) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("Clock not mocked via Set", func() { | ||||
| 			const ( | ||||
| 				outsideTolerance = int32(0) | ||||
| 				withinTolerance  = int32(1) | ||||
| 			) | ||||
| 
 | ||||
| 			It("uses time.After for After", func() { | ||||
| 				var tolerance int32 | ||||
| 				go func() { | ||||
| 					time.Sleep(10 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, withinTolerance) | ||||
| 				}() | ||||
| 				go func() { | ||||
| 					time.Sleep(30 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, outsideTolerance) | ||||
| 				}() | ||||
| 
 | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 
 | ||||
| 				<-testClock.After(20 * time.Millisecond) | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) | ||||
| 
 | ||||
| 				<-testClock.After(20 * time.Millisecond) | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("uses time.AfterFunc for AfterFunc", func() { | ||||
| 				var tolerance int32 | ||||
| 				go func() { | ||||
| 					time.Sleep(10 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, withinTolerance) | ||||
| 				}() | ||||
| 				go func() { | ||||
| 					time.Sleep(30 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, outsideTolerance) | ||||
| 				}() | ||||
| 
 | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 
 | ||||
| 				var wg sync.WaitGroup | ||||
| 				wg.Add(1) | ||||
| 				testClock.AfterFunc(20*time.Millisecond, func() { | ||||
| 					wg.Done() | ||||
| 				}) | ||||
| 				wg.Wait() | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) | ||||
| 
 | ||||
| 				wg.Add(1) | ||||
| 				testClock.AfterFunc(20*time.Millisecond, func() { | ||||
| 					wg.Done() | ||||
| 				}) | ||||
| 				wg.Wait() | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("uses time.Now for Now", func() { | ||||
| 				a := time.Now() | ||||
| 				b := testClock.Now() | ||||
| 				Expect(b.Sub(a).Round(10 * time.Millisecond)).To(Equal(0 * time.Millisecond)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("uses time.Since for Since", func() { | ||||
| 				past := time.Now().Add(-60 * time.Second) | ||||
| 				Expect(time.Since(past).Round(10 * time.Millisecond)). | ||||
| 					To(Equal(60 * time.Second)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("uses time.Sleep for Sleep", func() { | ||||
| 				var tolerance int32 | ||||
| 				go func() { | ||||
| 					time.Sleep(10 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, withinTolerance) | ||||
| 				}() | ||||
| 				go func() { | ||||
| 					time.Sleep(30 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, outsideTolerance) | ||||
| 				}() | ||||
| 
 | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 
 | ||||
| 				testClock.Sleep(20 * time.Millisecond) | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) | ||||
| 
 | ||||
| 				testClock.Sleep(20 * time.Millisecond) | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("uses time.Tick for Tick", func() { | ||||
| 				var tolerance int32 | ||||
| 				go func() { | ||||
| 					time.Sleep(10 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, withinTolerance) | ||||
| 				}() | ||||
| 				go func() { | ||||
| 					time.Sleep(50 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, outsideTolerance) | ||||
| 				}() | ||||
| 
 | ||||
| 				ch := testClock.Tick(20 * time.Millisecond) | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 				<-ch | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) | ||||
| 				<-ch | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) | ||||
| 				<-ch | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("uses time.Ticker for Ticker", func() { | ||||
| 				var tolerance int32 | ||||
| 				go func() { | ||||
| 					time.Sleep(10 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, withinTolerance) | ||||
| 				}() | ||||
| 				go func() { | ||||
| 					time.Sleep(50 * time.Millisecond) | ||||
| 					atomic.StoreInt32(&tolerance, outsideTolerance) | ||||
| 				}() | ||||
| 
 | ||||
| 				ticker := testClock.Ticker(20 * time.Millisecond) | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 				<-ticker.C | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) | ||||
| 				<-ticker.C | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) | ||||
| 				<-ticker.C | ||||
| 				Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("errors if Add is used", func() { | ||||
| 				err := testClock.Add(100 * time.Second) | ||||
| 				Expect(err).To(HaveOccurred()) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("Clock mocked via Set", func() { | ||||
| 			var now = time.Unix(testLocalEpoch, 0) | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				testClock.Set(now) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks After", func() { | ||||
| 				var after int32 | ||||
| 				ready := make(chan struct{}) | ||||
| 				ch := testClock.After(10 * time.Second) | ||||
| 				go func(ch <-chan time.Time) { | ||||
| 					close(ready) | ||||
| 					<-ch | ||||
| 					atomic.StoreInt32(&after, 1) | ||||
| 				}(ch) | ||||
| 				<-ready | ||||
| 
 | ||||
| 				err := testClock.Add(9 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err = testClock.Add(1 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks AfterFunc", func() { | ||||
| 				var after int32 | ||||
| 				testClock.AfterFunc(10*time.Second, func() { | ||||
| 					atomic.StoreInt32(&after, 1) | ||||
| 				}) | ||||
| 
 | ||||
| 				err := testClock.Add(9 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err = testClock.Add(1 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks AfterFunc with a stopped timer", func() { | ||||
| 				var after int32 | ||||
| 				timer := testClock.AfterFunc(10*time.Second, func() { | ||||
| 					atomic.StoreInt32(&after, 1) | ||||
| 				}) | ||||
| 				timer.Stop() | ||||
| 
 | ||||
| 				err := testClock.Add(11 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks Now", func() { | ||||
| 				Expect(testClock.Now()).To(Equal(now)) | ||||
| 				err := testClock.Add(123 * time.Hour) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(testClock.Now()).To(Equal(now.Add(123 * time.Hour))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks Since", func() { | ||||
| 				Expect(testClock.Since(time.Unix(testLocalEpoch-100, 0))). | ||||
| 					To(Equal(100 * time.Second)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks Sleep", func() { | ||||
| 				var after int32 | ||||
| 				ready := make(chan struct{}) | ||||
| 				go func() { | ||||
| 					close(ready) | ||||
| 					testClock.Sleep(10 * time.Second) | ||||
| 					atomic.StoreInt32(&after, 1) | ||||
| 				}() | ||||
| 				<-ready | ||||
| 
 | ||||
| 				err := testClock.Add(9 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err = testClock.Add(1 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks Tick", func() { | ||||
| 				var ticks int32 | ||||
| 				ready := make(chan struct{}) | ||||
| 				go func() { | ||||
| 					close(ready) | ||||
| 					tick := testClock.Tick(10 * time.Second) | ||||
| 					for ticks < 5 { | ||||
| 						<-tick | ||||
| 						atomic.AddInt32(&ticks, 1) | ||||
| 					} | ||||
| 				}() | ||||
| 				<-ready | ||||
| 
 | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err := testClock.Add(9 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err = testClock.Add(1 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1))) | ||||
| 
 | ||||
| 				err = testClock.Add(30 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4))) | ||||
| 
 | ||||
| 				err = testClock.Add(10 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks Ticker", func() { | ||||
| 				var ticks int32 | ||||
| 				ready := make(chan struct{}) | ||||
| 				go func() { | ||||
| 					ticker := testClock.Ticker(10 * time.Second) | ||||
| 					close(ready) | ||||
| 					for ticks < 5 { | ||||
| 						<-ticker.C | ||||
| 						atomic.AddInt32(&ticks, 1) | ||||
| 					} | ||||
| 				}() | ||||
| 				<-ready | ||||
| 
 | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err := testClock.Add(9 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err = testClock.Add(1 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1))) | ||||
| 
 | ||||
| 				err = testClock.Add(30 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4))) | ||||
| 
 | ||||
| 				err = testClock.Add(10 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("mocks Timer", func() { | ||||
| 				var after int32 | ||||
| 				ready := make(chan struct{}) | ||||
| 				go func() { | ||||
| 					timer := testClock.Timer(10 * time.Second) | ||||
| 					close(ready) | ||||
| 					<-timer.C | ||||
| 					atomic.AddInt32(&after, 1) | ||||
| 				}() | ||||
| 				<-ready | ||||
| 
 | ||||
| 				err := testClock.Add(9 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) | ||||
| 
 | ||||
| 				err = testClock.Add(1 * time.Second) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("Global time overridden", func() { | ||||
| 		var ( | ||||
| 			globalNow = time.Unix(testGlobalEpoch, 0) | ||||
| 			localNow  = time.Unix(testLocalEpoch, 0) | ||||
| 		) | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			clock.Set(globalNow) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("Clock not mocked via Set", func() { | ||||
| 			It("uses globally mocked Now", func() { | ||||
| 				Expect(testClock.Now()).To(Equal(globalNow)) | ||||
| 				err := clock.Add(123 * time.Hour) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(testClock.Now()).To(Equal(globalNow.Add(123 * time.Hour))) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("errors when Add is called on the local Clock", func() { | ||||
| 				err := testClock.Add(100 * time.Hour) | ||||
| 				Expect(err).To(HaveOccurred()) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("Clock is mocked via Set", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				testClock.Set(localNow) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("uses the local mock and ignores the global", func() { | ||||
| 				Expect(testClock.Now()).To(Equal(localNow)) | ||||
| 
 | ||||
| 				err := clock.Add(456 * time.Hour) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				err = testClock.Add(123 * time.Hour) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				Expect(testClock.Now()).To(Equal(localNow.Add(123 * time.Hour))) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|  | @ -10,7 +10,6 @@ import ( | |||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||
| 	"github.com/vmihailenco/msgpack/v5" | ||||
| ) | ||||
|  | @ -47,7 +46,7 @@ type csrf struct { | |||
| 	CodeVerifier string `msgpack:"cv,omitempty"` | ||||
| 
 | ||||
| 	cookieOpts *options.Cookie | ||||
| 	time       clock.Clock | ||||
| 	now        func() time.Time | ||||
| } | ||||
| 
 | ||||
| // csrtStateTrim will indicate the length of the state trimmed for the name of the csrf cookie
 | ||||
|  | @ -70,6 +69,7 @@ func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) { | |||
| 		CodeVerifier: codeVerifier, | ||||
| 
 | ||||
| 		cookieOpts: opts, | ||||
| 		now:        time.Now, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -187,7 +187,7 @@ func ClearExtraCsrfCookies(opts *options.Cookie, rw http.ResponseWriter, req *ht | |||
| 
 | ||||
| 	// delete the X oldest cookies
 | ||||
| 	slices.SortStableFunc(decodedCookies, func(a, b *csrf) int { | ||||
| 		return a.time.Now().Compare(b.time.Now()) | ||||
| 		return a.now().Compare(b.now()) | ||||
| 	}) | ||||
| 
 | ||||
| 	for i := 0; i < len(decodedCookies)-opts.CSRFPerRequestLimit; i++ { | ||||
|  | @ -223,7 +223,7 @@ func (c *csrf) encodeCookie() (string, error) { | |||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("error getting cookie secret: %v", err) | ||||
| 	} | ||||
| 	return encryption.SignedValue(secret, c.cookieName(), encrypted, c.time.Now()) | ||||
| 	return encryption.SignedValue(secret, c.cookieName(), encrypted, c.now()) | ||||
| } | ||||
| 
 | ||||
| // decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
 | ||||
|  | @ -249,10 +249,10 @@ func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) | |||
| 
 | ||||
| // unmarshalCSRF unmarshals decrypted data into a CSRF struct
 | ||||
| func unmarshalCSRF(decrypted []byte, opts *options.Cookie, csrfTime time.Time) (*csrf, error) { | ||||
| 	clock := clock.Clock{} | ||||
| 	clock.Set(csrfTime) | ||||
| 
 | ||||
| 	csrf := &csrf{cookieOpts: opts, time: clock} | ||||
| 	csrf := &csrf{ | ||||
| 		cookieOpts: opts, | ||||
| 		now:        func() time.Time { return csrfTime }, | ||||
| 	} | ||||
| 	if err := msgpack.Unmarshal(decrypted, csrf); err != nil { | ||||
| 		return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err) | ||||
| 	} | ||||
|  |  | |||
|  | @ -128,7 +128,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() { | |||
| 		testNow := time.Unix(nowEpoch, 0) | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			privateCSRF.time.Set(testNow) | ||||
| 			privateCSRF.now = func() time.Time { return testNow } | ||||
| 
 | ||||
| 			req = &http.Request{ | ||||
| 				Method: http.MethodGet, | ||||
|  | @ -144,7 +144,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() { | |||
| 		}) | ||||
| 
 | ||||
| 		AfterEach(func() { | ||||
| 			privateCSRF.time.Reset() | ||||
| 			privateCSRF.now = time.Now | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("SetCookie", func() { | ||||
|  | @ -200,17 +200,17 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() { | |||
| 				publicCSRF1, err := NewCSRF(cookieOpts, "verifier") | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				privateCSRF1 := publicCSRF1.(*csrf) | ||||
| 				privateCSRF1.time.Set(testNow) | ||||
| 				privateCSRF1.now = func() time.Time { return testNow } | ||||
| 
 | ||||
| 				publicCSRF2, err := NewCSRF(cookieOpts, "verifier") | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				privateCSRF2 := publicCSRF2.(*csrf) | ||||
| 				privateCSRF2.time.Set(testNow.Add(time.Minute)) | ||||
| 				privateCSRF2.now = func() time.Time { return testNow.Add(time.Minute) } | ||||
| 
 | ||||
| 				publicCSRF3, err := NewCSRF(cookieOpts, "verifier") | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				privateCSRF3 := publicCSRF3.(*csrf) | ||||
| 				privateCSRF3.time.Set(testNow.Add(time.Minute * 2)) | ||||
| 				privateCSRF3.now = func() time.Time { return testNow.Add(time.Minute * 2) } | ||||
| 
 | ||||
| 				cookies := []string{} | ||||
| 				for _, csrf := range []*csrf{privateCSRF1, privateCSRF2, privateCSRF3} { | ||||
|  |  | |||
|  | @ -130,7 +130,7 @@ var _ = Describe("CSRF Cookie Tests", func() { | |||
| 		testNow := time.Unix(nowEpoch, 0) | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			privateCSRF.time.Set(testNow) | ||||
| 			privateCSRF.now = func() time.Time { return testNow } | ||||
| 
 | ||||
| 			req = &http.Request{ | ||||
| 				Method: http.MethodGet, | ||||
|  | @ -146,7 +146,7 @@ var _ = Describe("CSRF Cookie Tests", func() { | |||
| 		}) | ||||
| 
 | ||||
| 		AfterEach(func() { | ||||
| 			privateCSRF.time.Reset() | ||||
| 			privateCSRF.now = time.Now | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("SetCookie", func() { | ||||
|  | @ -173,7 +173,7 @@ var _ = Describe("CSRF Cookie Tests", func() { | |||
| 		Context("LoadCSRFCookie", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				// we need to reset the time to ensure the cookie is valid
 | ||||
| 				privateCSRF.time.Reset() | ||||
| 				privateCSRF.now = time.Now | ||||
| 			}) | ||||
| 
 | ||||
| 			It("should return error when no cookie is set", func() { | ||||
|  |  | |||
|  | @ -11,7 +11,6 @@ import ( | |||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/providers" | ||||
| 	. "github.com/onsi/ginkgo/v2" | ||||
| 	. "github.com/onsi/gomega" | ||||
|  | @ -95,6 +94,7 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 		now := time.Now() | ||||
| 		createdPast := now.Add(-5 * time.Minute) | ||||
| 		createdFuture := now.Add(5 * time.Minute) | ||||
| 		clock := func() time.Time { return now } | ||||
| 
 | ||||
| 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||
| 			switch ss.RefreshToken { | ||||
|  | @ -120,6 +120,7 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 						Clock:        clock, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=InvalidNoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
|  | @ -127,24 +128,28 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 						Clock:        clock, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=ExpiredNoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdPast, | ||||
| 						Clock:        clock, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=RefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: refresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 						Clock:        clock, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=RefreshError": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: "RefreshError", | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 						Clock:        clock, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=NonExistent": | ||||
| 					return nil, fmt.Errorf("invalid cookie") | ||||
|  | @ -154,14 +159,6 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			clock.Set(now) | ||||
| 		}) | ||||
| 
 | ||||
| 		AfterEach(func() { | ||||
| 			clock.Reset() | ||||
| 		}) | ||||
| 
 | ||||
| 		type storedSessionLoaderTableInput struct { | ||||
| 			requestHeaders  http.Header | ||||
| 			existingSession *sessionsapi.SessionState | ||||
|  | @ -200,7 +197,15 @@ var _ = Describe("Stored Session Suite", func() { | |||
| 				})) | ||||
| 				handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 				Expect(gotSession).To(Equal(in.expectedSession)) | ||||
| 				// Compare, ignoring testing Clock.
 | ||||
| 				if in.expectedSession == nil { | ||||
| 					Expect(gotSession).To(BeNil()) | ||||
| 					return | ||||
| 				} | ||||
| 				Expect(gotSession).ToNot(BeNil()) | ||||
| 				got := *gotSession | ||||
| 				got.Clock = nil | ||||
| 				Expect(&got).To(Equal(in.expectedSession)) | ||||
| 			}, | ||||
| 			Entry("with no cookie", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders:  http.Header{}, | ||||
|  |  | |||
|  | @ -17,8 +17,9 @@ func TestRefresh(t *testing.T) { | |||
| 	now := time.Unix(1234567890, 10) | ||||
| 	expires := time.Unix(1234567890, 0) | ||||
| 
 | ||||
| 	ss := &sessions.SessionState{} | ||||
| 	ss.Clock.Set(now) | ||||
| 	ss := &sessions.SessionState{ | ||||
| 		Clock: func() time.Time { return now }, | ||||
| 	} | ||||
| 	ss.SetExpiresOn(expires) | ||||
| 
 | ||||
| 	refreshed, err := p.RefreshSession(context.Background(), ss) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue