From 110d51d1d742b5cde9cb629d6912f47e5d9ab878 Mon Sep 17 00:00:00 2001 From: David Symonds Date: Tue, 28 Oct 2025 20:05:02 +1100 Subject: [PATCH] test: replace mock pkg/clock with narrowly targeted stub clocks. (#3238) The package under pkg/clock is github.com/benbjohnson/clock, which is archived. It's also way more complex than is what is actually needed here, so we can entirely remove the dependency and remove the helper package. Fixes #2840. Signed-off-by: David Symonds --- CHANGELOG.md | 1 + go.mod | 1 - go.sum | 2 - pkg/apis/sessions/session_state.go | 18 +- pkg/apis/sessions/session_state_test.go | 6 +- pkg/clock/clock.go | 157 ---------- pkg/clock/clock_suite_test.go | 17 -- pkg/clock/clock_test.go | 380 ------------------------ pkg/cookies/csrf.go | 16 +- pkg/cookies/csrf_per_request_test.go | 10 +- pkg/cookies/csrf_test.go | 6 +- pkg/middleware/stored_session_test.go | 25 +- providers/provider_default_test.go | 5 +- 13 files changed, 50 insertions(+), 594 deletions(-) delete mode 100644 pkg/clock/clock.go delete mode 100644 pkg/clock/clock_suite_test.go delete mode 100644 pkg/clock/clock_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index c52eedef..e014aee3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - [#3228](https://github.com/oauth2-proxy/oauth2-proxy/pull/3228) fix: use GetSecret() in ticket.go makeCookie to respect cookie-secret-file (@stagswtf) - [#3244](https://github.com/oauth2-proxy/oauth2-proxy/pull/3244) chore(deps): upgrade to latest go1.25.3 (@tuunit) +- [#3238](https://github.com/oauth2-proxy/oauth2-proxy/pull/3238) chore: Replace pkg/clock with narrowly targeted stub clocks (@dsymonds) # V7.12.0 diff --git a/go.mod b/go.mod index 3aeeda0a..0e14464c 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/Bose/minisentinel v0.0.0-20200130220412-917c5a9223bb github.com/a8m/envsubst v1.4.3 github.com/alicebob/miniredis/v2 v2.35.0 - github.com/benbjohnson/clock v1.3.5 github.com/bitly/go-simplejson v0.5.1 github.com/bsm/redislock v0.9.4 github.com/coreos/go-oidc/v3 v3.14.1 diff --git a/go.sum b/go.sum index 2e8db1ef..8bc31660 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,6 @@ github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGn github.com/alicebob/miniredis/v2 v2.11.1/go.mod h1:UA48pmi7aSazcGAvcdKcBB49z521IC9VjTTRz2nIaJE= github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= -github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= -github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow= diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index b5e4fc83..5b063c3f 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -7,7 +7,6 @@ import ( "io" "time" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/pierrec/lz4/v4" "github.com/vmihailenco/msgpack/v5" @@ -30,8 +29,15 @@ type SessionState struct { PreferredUsername string `msgpack:"pu,omitempty"` // Internal helpers, not serialized - Clock clock.Clock `msgpack:"-"` - Lock Lock `msgpack:"-"` + Clock func() time.Time `msgpack:"-"` // override for time.Now, for testing + Lock Lock `msgpack:"-"` +} + +func (s *SessionState) now() time.Time { + if s.Clock != nil { + return s.Clock() + } + return time.Now() } func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { @@ -64,7 +70,7 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { // CreatedAtNow sets a SessionState's CreatedAt to now func (s *SessionState) CreatedAtNow() { - now := s.Clock.Now() + now := s.now() s.CreatedAt = &now } @@ -85,7 +91,7 @@ func (s *SessionState) ExpiresIn(d time.Duration) { // IsExpired checks whether the session has expired func (s *SessionState) IsExpired() bool { - if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { + if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.now()) { return true } return false @@ -94,7 +100,7 @@ func (s *SessionState) IsExpired() bool { // Age returns the age of a session func (s *SessionState) Age() time.Duration { if s.CreatedAt != nil && !s.CreatedAt.IsZero() { - return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) + return s.now().Truncate(time.Second).Sub(*s.CreatedAt) } return 0 } diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index e12c2776..442fcea8 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -22,7 +22,7 @@ func TestCreatedAtNow(t *testing.T) { ss := &SessionState{} now := time.Unix(1234567890, 0) - ss.Clock.Set(now) + ss.Clock = func() time.Time { return now } ss.CreatedAtNow() g.Expect(*ss.CreatedAt).To(Equal(now)) @@ -33,9 +33,9 @@ func TestExpiresIn(t *testing.T) { ss := &SessionState{} now := time.Unix(1234567890, 0) - ss.Clock.Set(now) + ss.Clock = func() time.Time { return now } - ttl := time.Duration(743) * time.Second + ttl := 743 * time.Second ss.ExpiresIn(ttl) g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl))) diff --git a/pkg/clock/clock.go b/pkg/clock/clock.go deleted file mode 100644 index 887bf0aa..00000000 --- a/pkg/clock/clock.go +++ /dev/null @@ -1,157 +0,0 @@ -package clock - -import ( - "errors" - "sync" - "time" - - clockapi "github.com/benbjohnson/clock" -) - -var ( - globalClock = clockapi.New() - mu sync.Mutex -) - -// Set the global clock to a clockapi.Mock with the given time.Time -func Set(t time.Time) { - mu.Lock() - defer mu.Unlock() - mock, ok := globalClock.(*clockapi.Mock) - if !ok { - mock = clockapi.NewMock() - } - mock.Set(t) - globalClock = mock -} - -// Add moves the mocked global clock forward the given duration. It will error -// if the global clock is not mocked. -func Add(d time.Duration) error { - mu.Lock() - defer mu.Unlock() - mock, ok := globalClock.(*clockapi.Mock) - if !ok { - return errors.New("time not mocked") - } - mock.Add(d) - return nil -} - -// Reset sets the global clock to a pure time implementation. Returns any -// existing Mock if set in case lingering time operations are attached to it. -func Reset() *clockapi.Mock { - mu.Lock() - defer mu.Unlock() - existing := globalClock - globalClock = clockapi.New() - - mock, ok := existing.(*clockapi.Mock) - if !ok { - return nil - } - return mock -} - -// Clock is a non-package level wrapper around time that supports stubbing. -// It will use its localized stubs (allowing for parallelized unit tests -// where package level stubbing would cause issues). It falls back to any -// package level time stubs for non-parallel, cross-package integration -// testing scenarios. -// -// If nothing is stubbed, it defaults to default time behavior in the time -// package. -type Clock struct { - mock *clockapi.Mock -} - -// Set sets the Clock to a clock.Mock at the given time.Time -func (c *Clock) Set(t time.Time) { - if c.mock == nil { - c.mock = clockapi.NewMock() - } - c.mock.Set(t) -} - -// Add moves clock forward time.Duration if it is mocked. It will error -// if the clock is not mocked. -func (c *Clock) Add(d time.Duration) error { - if c.mock == nil { - return errors.New("clock not mocked") - } - c.mock.Add(d) - return nil -} - -// Reset removes local clock.Mock. Returns any existing Mock if set in case -// lingering time operations are attached to it. -func (c *Clock) Reset() *clockapi.Mock { - existing := c.mock - c.mock = nil - return existing -} - -func (c *Clock) After(d time.Duration) <-chan time.Time { - m := c.mock - if m == nil { - return globalClock.After(d) - } - return m.After(d) -} - -func (c *Clock) AfterFunc(d time.Duration, f func()) *clockapi.Timer { - m := c.mock - if m == nil { - return globalClock.AfterFunc(d, f) - } - return m.AfterFunc(d, f) -} - -func (c *Clock) Now() time.Time { - m := c.mock - if m == nil { - return globalClock.Now() - } - return m.Now() -} - -func (c *Clock) Since(t time.Time) time.Duration { - m := c.mock - if m == nil { - return globalClock.Since(t) - } - return m.Since(t) -} - -func (c *Clock) Sleep(d time.Duration) { - m := c.mock - if m == nil { - globalClock.Sleep(d) - return - } - m.Sleep(d) -} - -func (c *Clock) Tick(d time.Duration) <-chan time.Time { - m := c.mock - if m == nil { - return globalClock.Tick(d) - } - return m.Tick(d) -} - -func (c *Clock) Ticker(d time.Duration) *clockapi.Ticker { - m := c.mock - if m == nil { - return globalClock.Ticker(d) - } - return m.Ticker(d) -} - -func (c *Clock) Timer(d time.Duration) *clockapi.Timer { - m := c.mock - if m == nil { - return globalClock.Timer(d) - } - return m.Timer(d) -} diff --git a/pkg/clock/clock_suite_test.go b/pkg/clock/clock_suite_test.go deleted file mode 100644 index 99d39f24..00000000 --- a/pkg/clock/clock_suite_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package clock_test - -import ( - "testing" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -func TestClockSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) - - RegisterFailHandler(Fail) - RunSpecs(t, "Clock") -} diff --git a/pkg/clock/clock_test.go b/pkg/clock/clock_test.go deleted file mode 100644 index e1b6d440..00000000 --- a/pkg/clock/clock_test.go +++ /dev/null @@ -1,380 +0,0 @@ -package clock_test - -import ( - "sync" - "sync/atomic" - "time" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -const ( - testGlobalEpoch = 1000000000 - testLocalEpoch = 1234567890 -) - -var _ = Describe("Clock suite", func() { - var testClock = clock.Clock{} - - AfterEach(func() { - clock.Reset() - testClock.Reset() - }) - - Context("Global time not overridden", func() { - It("errors when trying to Add", func() { - err := clock.Add(123 * time.Hour) - Expect(err).To(HaveOccurred()) - }) - - Context("Clock not mocked via Set", func() { - const ( - outsideTolerance = int32(0) - withinTolerance = int32(1) - ) - - It("uses time.After for After", func() { - var tolerance int32 - go func() { - time.Sleep(10 * time.Millisecond) - atomic.StoreInt32(&tolerance, withinTolerance) - }() - go func() { - time.Sleep(30 * time.Millisecond) - atomic.StoreInt32(&tolerance, outsideTolerance) - }() - - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - - <-testClock.After(20 * time.Millisecond) - Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) - - <-testClock.After(20 * time.Millisecond) - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - }) - - It("uses time.AfterFunc for AfterFunc", func() { - var tolerance int32 - go func() { - time.Sleep(10 * time.Millisecond) - atomic.StoreInt32(&tolerance, withinTolerance) - }() - go func() { - time.Sleep(30 * time.Millisecond) - atomic.StoreInt32(&tolerance, outsideTolerance) - }() - - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - - var wg sync.WaitGroup - wg.Add(1) - testClock.AfterFunc(20*time.Millisecond, func() { - wg.Done() - }) - wg.Wait() - Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) - - wg.Add(1) - testClock.AfterFunc(20*time.Millisecond, func() { - wg.Done() - }) - wg.Wait() - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - }) - - It("uses time.Now for Now", func() { - a := time.Now() - b := testClock.Now() - Expect(b.Sub(a).Round(10 * time.Millisecond)).To(Equal(0 * time.Millisecond)) - }) - - It("uses time.Since for Since", func() { - past := time.Now().Add(-60 * time.Second) - Expect(time.Since(past).Round(10 * time.Millisecond)). - To(Equal(60 * time.Second)) - }) - - It("uses time.Sleep for Sleep", func() { - var tolerance int32 - go func() { - time.Sleep(10 * time.Millisecond) - atomic.StoreInt32(&tolerance, withinTolerance) - }() - go func() { - time.Sleep(30 * time.Millisecond) - atomic.StoreInt32(&tolerance, outsideTolerance) - }() - - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - - testClock.Sleep(20 * time.Millisecond) - Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) - - testClock.Sleep(20 * time.Millisecond) - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - }) - - It("uses time.Tick for Tick", func() { - var tolerance int32 - go func() { - time.Sleep(10 * time.Millisecond) - atomic.StoreInt32(&tolerance, withinTolerance) - }() - go func() { - time.Sleep(50 * time.Millisecond) - atomic.StoreInt32(&tolerance, outsideTolerance) - }() - - ch := testClock.Tick(20 * time.Millisecond) - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - <-ch - Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) - <-ch - Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) - <-ch - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - }) - - It("uses time.Ticker for Ticker", func() { - var tolerance int32 - go func() { - time.Sleep(10 * time.Millisecond) - atomic.StoreInt32(&tolerance, withinTolerance) - }() - go func() { - time.Sleep(50 * time.Millisecond) - atomic.StoreInt32(&tolerance, outsideTolerance) - }() - - ticker := testClock.Ticker(20 * time.Millisecond) - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - <-ticker.C - Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) - <-ticker.C - Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance)) - <-ticker.C - Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance)) - }) - - It("errors if Add is used", func() { - err := testClock.Add(100 * time.Second) - Expect(err).To(HaveOccurred()) - }) - }) - - Context("Clock mocked via Set", func() { - var now = time.Unix(testLocalEpoch, 0) - - BeforeEach(func() { - testClock.Set(now) - }) - - It("mocks After", func() { - var after int32 - ready := make(chan struct{}) - ch := testClock.After(10 * time.Second) - go func(ch <-chan time.Time) { - close(ready) - <-ch - atomic.StoreInt32(&after, 1) - }(ch) - <-ready - - err := testClock.Add(9 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) - - err = testClock.Add(1 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) - }) - - It("mocks AfterFunc", func() { - var after int32 - testClock.AfterFunc(10*time.Second, func() { - atomic.StoreInt32(&after, 1) - }) - - err := testClock.Add(9 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) - - err = testClock.Add(1 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) - }) - - It("mocks AfterFunc with a stopped timer", func() { - var after int32 - timer := testClock.AfterFunc(10*time.Second, func() { - atomic.StoreInt32(&after, 1) - }) - timer.Stop() - - err := testClock.Add(11 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) - }) - - It("mocks Now", func() { - Expect(testClock.Now()).To(Equal(now)) - err := testClock.Add(123 * time.Hour) - Expect(err).ToNot(HaveOccurred()) - Expect(testClock.Now()).To(Equal(now.Add(123 * time.Hour))) - }) - - It("mocks Since", func() { - Expect(testClock.Since(time.Unix(testLocalEpoch-100, 0))). - To(Equal(100 * time.Second)) - }) - - It("mocks Sleep", func() { - var after int32 - ready := make(chan struct{}) - go func() { - close(ready) - testClock.Sleep(10 * time.Second) - atomic.StoreInt32(&after, 1) - }() - <-ready - - err := testClock.Add(9 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) - - err = testClock.Add(1 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) - }) - - It("mocks Tick", func() { - var ticks int32 - ready := make(chan struct{}) - go func() { - close(ready) - tick := testClock.Tick(10 * time.Second) - for ticks < 5 { - <-tick - atomic.AddInt32(&ticks, 1) - } - }() - <-ready - - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) - - err := testClock.Add(9 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) - - err = testClock.Add(1 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1))) - - err = testClock.Add(30 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4))) - - err = testClock.Add(10 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5))) - }) - - It("mocks Ticker", func() { - var ticks int32 - ready := make(chan struct{}) - go func() { - ticker := testClock.Ticker(10 * time.Second) - close(ready) - for ticks < 5 { - <-ticker.C - atomic.AddInt32(&ticks, 1) - } - }() - <-ready - - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) - - err := testClock.Add(9 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0))) - - err = testClock.Add(1 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1))) - - err = testClock.Add(30 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4))) - - err = testClock.Add(10 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5))) - }) - - It("mocks Timer", func() { - var after int32 - ready := make(chan struct{}) - go func() { - timer := testClock.Timer(10 * time.Second) - close(ready) - <-timer.C - atomic.AddInt32(&after, 1) - }() - <-ready - - err := testClock.Add(9 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(0))) - - err = testClock.Add(1 * time.Second) - Expect(err).ToNot(HaveOccurred()) - Expect(atomic.LoadInt32(&after)).To(Equal(int32(1))) - }) - }) - }) - - Context("Global time overridden", func() { - var ( - globalNow = time.Unix(testGlobalEpoch, 0) - localNow = time.Unix(testLocalEpoch, 0) - ) - - BeforeEach(func() { - clock.Set(globalNow) - }) - - Context("Clock not mocked via Set", func() { - It("uses globally mocked Now", func() { - Expect(testClock.Now()).To(Equal(globalNow)) - err := clock.Add(123 * time.Hour) - Expect(err).ToNot(HaveOccurred()) - Expect(testClock.Now()).To(Equal(globalNow.Add(123 * time.Hour))) - }) - - It("errors when Add is called on the local Clock", func() { - err := testClock.Add(100 * time.Hour) - Expect(err).To(HaveOccurred()) - }) - }) - - Context("Clock is mocked via Set", func() { - BeforeEach(func() { - testClock.Set(localNow) - }) - - It("uses the local mock and ignores the global", func() { - Expect(testClock.Now()).To(Equal(localNow)) - - err := clock.Add(456 * time.Hour) - Expect(err).ToNot(HaveOccurred()) - - err = testClock.Add(123 * time.Hour) - Expect(err).ToNot(HaveOccurred()) - - Expect(testClock.Now()).To(Equal(localNow.Add(123 * time.Hour))) - }) - }) - }) -}) diff --git a/pkg/cookies/csrf.go b/pkg/cookies/csrf.go index 3b8efaf3..939578a2 100644 --- a/pkg/cookies/csrf.go +++ b/pkg/cookies/csrf.go @@ -10,7 +10,6 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" "github.com/vmihailenco/msgpack/v5" ) @@ -47,7 +46,7 @@ type csrf struct { CodeVerifier string `msgpack:"cv,omitempty"` cookieOpts *options.Cookie - time clock.Clock + clock func() time.Time } // csrtStateTrim will indicate the length of the state trimmed for the name of the csrf cookie @@ -70,6 +69,7 @@ func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) { CodeVerifier: codeVerifier, cookieOpts: opts, + clock: time.Now, }, nil } @@ -187,7 +187,7 @@ func ClearExtraCsrfCookies(opts *options.Cookie, rw http.ResponseWriter, req *ht // delete the X oldest cookies slices.SortStableFunc(decodedCookies, func(a, b *csrf) int { - return a.time.Now().Compare(b.time.Now()) + return a.clock().Compare(b.clock()) }) for i := 0; i < len(decodedCookies)-opts.CSRFPerRequestLimit; i++ { @@ -223,7 +223,7 @@ func (c *csrf) encodeCookie() (string, error) { if err != nil { return "", fmt.Errorf("error getting cookie secret: %v", err) } - return encryption.SignedValue(secret, c.cookieName(), encrypted, c.time.Now()) + return encryption.SignedValue(secret, c.cookieName(), encrypted, c.clock()) } // decodeCSRFCookie validates the signature then decrypts and decodes a CSRF @@ -249,10 +249,10 @@ func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) // unmarshalCSRF unmarshals decrypted data into a CSRF struct func unmarshalCSRF(decrypted []byte, opts *options.Cookie, csrfTime time.Time) (*csrf, error) { - clock := clock.Clock{} - clock.Set(csrfTime) - - csrf := &csrf{cookieOpts: opts, time: clock} + csrf := &csrf{ + cookieOpts: opts, + clock: func() time.Time { return csrfTime }, + } if err := msgpack.Unmarshal(decrypted, csrf); err != nil { return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err) } diff --git a/pkg/cookies/csrf_per_request_test.go b/pkg/cookies/csrf_per_request_test.go index 9b7d4e59..6a17013b 100644 --- a/pkg/cookies/csrf_per_request_test.go +++ b/pkg/cookies/csrf_per_request_test.go @@ -128,7 +128,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() { testNow := time.Unix(nowEpoch, 0) BeforeEach(func() { - privateCSRF.time.Set(testNow) + privateCSRF.clock = func() time.Time { return testNow } req = &http.Request{ Method: http.MethodGet, @@ -144,7 +144,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() { }) AfterEach(func() { - privateCSRF.time.Reset() + privateCSRF.clock = time.Now }) Context("SetCookie", func() { @@ -200,17 +200,17 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() { publicCSRF1, err := NewCSRF(cookieOpts, "verifier") Expect(err).ToNot(HaveOccurred()) privateCSRF1 := publicCSRF1.(*csrf) - privateCSRF1.time.Set(testNow) + privateCSRF1.clock = func() time.Time { return testNow } publicCSRF2, err := NewCSRF(cookieOpts, "verifier") Expect(err).ToNot(HaveOccurred()) privateCSRF2 := publicCSRF2.(*csrf) - privateCSRF2.time.Set(testNow.Add(time.Minute)) + privateCSRF2.clock = func() time.Time { return testNow.Add(time.Minute) } publicCSRF3, err := NewCSRF(cookieOpts, "verifier") Expect(err).ToNot(HaveOccurred()) privateCSRF3 := publicCSRF3.(*csrf) - privateCSRF3.time.Set(testNow.Add(time.Minute * 2)) + privateCSRF3.clock = func() time.Time { return testNow.Add(time.Minute * 2) } cookies := []string{} for _, csrf := range []*csrf{privateCSRF1, privateCSRF2, privateCSRF3} { diff --git a/pkg/cookies/csrf_test.go b/pkg/cookies/csrf_test.go index 37527bd0..085b91df 100644 --- a/pkg/cookies/csrf_test.go +++ b/pkg/cookies/csrf_test.go @@ -130,7 +130,7 @@ var _ = Describe("CSRF Cookie Tests", func() { testNow := time.Unix(nowEpoch, 0) BeforeEach(func() { - privateCSRF.time.Set(testNow) + privateCSRF.clock = func() time.Time { return testNow } req = &http.Request{ Method: http.MethodGet, @@ -146,7 +146,7 @@ var _ = Describe("CSRF Cookie Tests", func() { }) AfterEach(func() { - privateCSRF.time.Reset() + privateCSRF.clock = time.Now }) Context("SetCookie", func() { @@ -173,7 +173,7 @@ var _ = Describe("CSRF Cookie Tests", func() { Context("LoadCSRFCookie", func() { BeforeEach(func() { // we need to reset the time to ensure the cookie is valid - privateCSRF.time.Reset() + privateCSRF.clock = time.Now }) It("should return error when no cookie is set", func() { diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 904c2028..d8e78f2f 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -11,7 +11,6 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" "github.com/oauth2-proxy/oauth2-proxy/v7/providers" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -95,6 +94,7 @@ var _ = Describe("Stored Session Suite", func() { now := time.Now() createdPast := now.Add(-5 * time.Minute) createdFuture := now.Add(5 * time.Minute) + clock := func() time.Time { return now } var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { @@ -120,6 +120,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Clock: clock, }, nil case "_oauth2_proxy=InvalidNoRefreshSession": return &sessionsapi.SessionState{ @@ -127,24 +128,28 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Clock: clock, }, nil case "_oauth2_proxy=ExpiredNoRefreshSession": return &sessionsapi.SessionState{ RefreshToken: noRefresh, CreatedAt: &createdPast, ExpiresOn: &createdPast, + Clock: clock, }, nil case "_oauth2_proxy=RefreshSession": return &sessionsapi.SessionState{ RefreshToken: refresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Clock: clock, }, nil case "_oauth2_proxy=RefreshError": return &sessionsapi.SessionState{ RefreshToken: "RefreshError", CreatedAt: &createdPast, ExpiresOn: &createdFuture, + Clock: clock, }, nil case "_oauth2_proxy=NonExistent": return nil, fmt.Errorf("invalid cookie") @@ -154,14 +159,6 @@ var _ = Describe("Stored Session Suite", func() { }, } - BeforeEach(func() { - clock.Set(now) - }) - - AfterEach(func() { - clock.Reset() - }) - type storedSessionLoaderTableInput struct { requestHeaders http.Header existingSession *sessionsapi.SessionState @@ -200,7 +197,15 @@ var _ = Describe("Stored Session Suite", func() { })) handler.ServeHTTP(rw, req) - Expect(gotSession).To(Equal(in.expectedSession)) + // Compare, ignoring testing Clock. + if in.expectedSession == nil { + Expect(gotSession).To(BeNil()) + return + } + Expect(gotSession).ToNot(BeNil()) + got := *gotSession + got.Clock = nil + Expect(&got).To(Equal(in.expectedSession)) }, Entry("with no cookie", storedSessionLoaderTableInput{ requestHeaders: http.Header{}, diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index f678d13d..0fbe7abd 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -17,8 +17,9 @@ func TestRefresh(t *testing.T) { now := time.Unix(1234567890, 10) expires := time.Unix(1234567890, 0) - ss := &sessions.SessionState{} - ss.Clock.Set(now) + ss := &sessions.SessionState{ + Clock: func() time.Time { return now }, + } ss.SetExpiresOn(expires) refreshed, err := p.RefreshSession(context.Background(), ss)