test: replace mock pkg/clock with narrowly targeted stub clocks. (#3238)

The package under pkg/clock is github.com/benbjohnson/clock, which is
archived. It's also way more complex than is what is actually needed
here, so we can entirely remove the dependency and remove the helper
package.

Fixes #2840.

Signed-off-by: David Symonds <dsymonds@gmail.com>
This commit is contained in:
David Symonds 2025-10-28 20:05:02 +11:00 committed by GitHub
parent 8f687e4d0c
commit 110d51d1d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 50 additions and 594 deletions

View File

@ -10,6 +10,7 @@
- [#3228](https://github.com/oauth2-proxy/oauth2-proxy/pull/3228) fix: use GetSecret() in ticket.go makeCookie to respect cookie-secret-file (@stagswtf)
- [#3244](https://github.com/oauth2-proxy/oauth2-proxy/pull/3244) chore(deps): upgrade to latest go1.25.3 (@tuunit)
- [#3238](https://github.com/oauth2-proxy/oauth2-proxy/pull/3238) chore: Replace pkg/clock with narrowly targeted stub clocks (@dsymonds)
# V7.12.0

1
go.mod
View File

@ -7,7 +7,6 @@ require (
github.com/Bose/minisentinel v0.0.0-20200130220412-917c5a9223bb
github.com/a8m/envsubst v1.4.3
github.com/alicebob/miniredis/v2 v2.35.0
github.com/benbjohnson/clock v1.3.5
github.com/bitly/go-simplejson v0.5.1
github.com/bsm/redislock v0.9.4
github.com/coreos/go-oidc/v3 v3.14.1

2
go.sum
View File

@ -14,8 +14,6 @@ github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGn
github.com/alicebob/miniredis/v2 v2.11.1/go.mod h1:UA48pmi7aSazcGAvcdKcBB49z521IC9VjTTRz2nIaJE=
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o=
github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow=

View File

@ -7,7 +7,6 @@ import (
"io"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/pierrec/lz4/v4"
"github.com/vmihailenco/msgpack/v5"
@ -30,8 +29,15 @@ type SessionState struct {
PreferredUsername string `msgpack:"pu,omitempty"`
// Internal helpers, not serialized
Clock clock.Clock `msgpack:"-"`
Lock Lock `msgpack:"-"`
Clock func() time.Time `msgpack:"-"` // override for time.Now, for testing
Lock Lock `msgpack:"-"`
}
func (s *SessionState) now() time.Time {
if s.Clock != nil {
return s.Clock()
}
return time.Now()
}
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
@ -64,7 +70,7 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
// CreatedAtNow sets a SessionState's CreatedAt to now
func (s *SessionState) CreatedAtNow() {
now := s.Clock.Now()
now := s.now()
s.CreatedAt = &now
}
@ -85,7 +91,7 @@ func (s *SessionState) ExpiresIn(d time.Duration) {
// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool {
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) {
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.now()) {
return true
}
return false
@ -94,7 +100,7 @@ func (s *SessionState) IsExpired() bool {
// Age returns the age of a session
func (s *SessionState) Age() time.Duration {
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt)
return s.now().Truncate(time.Second).Sub(*s.CreatedAt)
}
return 0
}

View File

@ -22,7 +22,7 @@ func TestCreatedAtNow(t *testing.T) {
ss := &SessionState{}
now := time.Unix(1234567890, 0)
ss.Clock.Set(now)
ss.Clock = func() time.Time { return now }
ss.CreatedAtNow()
g.Expect(*ss.CreatedAt).To(Equal(now))
@ -33,9 +33,9 @@ func TestExpiresIn(t *testing.T) {
ss := &SessionState{}
now := time.Unix(1234567890, 0)
ss.Clock.Set(now)
ss.Clock = func() time.Time { return now }
ttl := time.Duration(743) * time.Second
ttl := 743 * time.Second
ss.ExpiresIn(ttl)
g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl)))

View File

@ -1,157 +0,0 @@
package clock
import (
"errors"
"sync"
"time"
clockapi "github.com/benbjohnson/clock"
)
var (
globalClock = clockapi.New()
mu sync.Mutex
)
// Set the global clock to a clockapi.Mock with the given time.Time
func Set(t time.Time) {
mu.Lock()
defer mu.Unlock()
mock, ok := globalClock.(*clockapi.Mock)
if !ok {
mock = clockapi.NewMock()
}
mock.Set(t)
globalClock = mock
}
// Add moves the mocked global clock forward the given duration. It will error
// if the global clock is not mocked.
func Add(d time.Duration) error {
mu.Lock()
defer mu.Unlock()
mock, ok := globalClock.(*clockapi.Mock)
if !ok {
return errors.New("time not mocked")
}
mock.Add(d)
return nil
}
// Reset sets the global clock to a pure time implementation. Returns any
// existing Mock if set in case lingering time operations are attached to it.
func Reset() *clockapi.Mock {
mu.Lock()
defer mu.Unlock()
existing := globalClock
globalClock = clockapi.New()
mock, ok := existing.(*clockapi.Mock)
if !ok {
return nil
}
return mock
}
// Clock is a non-package level wrapper around time that supports stubbing.
// It will use its localized stubs (allowing for parallelized unit tests
// where package level stubbing would cause issues). It falls back to any
// package level time stubs for non-parallel, cross-package integration
// testing scenarios.
//
// If nothing is stubbed, it defaults to default time behavior in the time
// package.
type Clock struct {
mock *clockapi.Mock
}
// Set sets the Clock to a clock.Mock at the given time.Time
func (c *Clock) Set(t time.Time) {
if c.mock == nil {
c.mock = clockapi.NewMock()
}
c.mock.Set(t)
}
// Add moves clock forward time.Duration if it is mocked. It will error
// if the clock is not mocked.
func (c *Clock) Add(d time.Duration) error {
if c.mock == nil {
return errors.New("clock not mocked")
}
c.mock.Add(d)
return nil
}
// Reset removes local clock.Mock. Returns any existing Mock if set in case
// lingering time operations are attached to it.
func (c *Clock) Reset() *clockapi.Mock {
existing := c.mock
c.mock = nil
return existing
}
func (c *Clock) After(d time.Duration) <-chan time.Time {
m := c.mock
if m == nil {
return globalClock.After(d)
}
return m.After(d)
}
func (c *Clock) AfterFunc(d time.Duration, f func()) *clockapi.Timer {
m := c.mock
if m == nil {
return globalClock.AfterFunc(d, f)
}
return m.AfterFunc(d, f)
}
func (c *Clock) Now() time.Time {
m := c.mock
if m == nil {
return globalClock.Now()
}
return m.Now()
}
func (c *Clock) Since(t time.Time) time.Duration {
m := c.mock
if m == nil {
return globalClock.Since(t)
}
return m.Since(t)
}
func (c *Clock) Sleep(d time.Duration) {
m := c.mock
if m == nil {
globalClock.Sleep(d)
return
}
m.Sleep(d)
}
func (c *Clock) Tick(d time.Duration) <-chan time.Time {
m := c.mock
if m == nil {
return globalClock.Tick(d)
}
return m.Tick(d)
}
func (c *Clock) Ticker(d time.Duration) *clockapi.Ticker {
m := c.mock
if m == nil {
return globalClock.Ticker(d)
}
return m.Ticker(d)
}
func (c *Clock) Timer(d time.Duration) *clockapi.Timer {
m := c.mock
if m == nil {
return globalClock.Timer(d)
}
return m.Timer(d)
}

View File

@ -1,17 +0,0 @@
package clock_test
import (
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestClockSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
logger.SetErrOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Clock")
}

View File

@ -1,380 +0,0 @@
package clock_test
import (
"sync"
"sync/atomic"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const (
testGlobalEpoch = 1000000000
testLocalEpoch = 1234567890
)
var _ = Describe("Clock suite", func() {
var testClock = clock.Clock{}
AfterEach(func() {
clock.Reset()
testClock.Reset()
})
Context("Global time not overridden", func() {
It("errors when trying to Add", func() {
err := clock.Add(123 * time.Hour)
Expect(err).To(HaveOccurred())
})
Context("Clock not mocked via Set", func() {
const (
outsideTolerance = int32(0)
withinTolerance = int32(1)
)
It("uses time.After for After", func() {
var tolerance int32
go func() {
time.Sleep(10 * time.Millisecond)
atomic.StoreInt32(&tolerance, withinTolerance)
}()
go func() {
time.Sleep(30 * time.Millisecond)
atomic.StoreInt32(&tolerance, outsideTolerance)
}()
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
<-testClock.After(20 * time.Millisecond)
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
<-testClock.After(20 * time.Millisecond)
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
})
It("uses time.AfterFunc for AfterFunc", func() {
var tolerance int32
go func() {
time.Sleep(10 * time.Millisecond)
atomic.StoreInt32(&tolerance, withinTolerance)
}()
go func() {
time.Sleep(30 * time.Millisecond)
atomic.StoreInt32(&tolerance, outsideTolerance)
}()
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
var wg sync.WaitGroup
wg.Add(1)
testClock.AfterFunc(20*time.Millisecond, func() {
wg.Done()
})
wg.Wait()
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
wg.Add(1)
testClock.AfterFunc(20*time.Millisecond, func() {
wg.Done()
})
wg.Wait()
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
})
It("uses time.Now for Now", func() {
a := time.Now()
b := testClock.Now()
Expect(b.Sub(a).Round(10 * time.Millisecond)).To(Equal(0 * time.Millisecond))
})
It("uses time.Since for Since", func() {
past := time.Now().Add(-60 * time.Second)
Expect(time.Since(past).Round(10 * time.Millisecond)).
To(Equal(60 * time.Second))
})
It("uses time.Sleep for Sleep", func() {
var tolerance int32
go func() {
time.Sleep(10 * time.Millisecond)
atomic.StoreInt32(&tolerance, withinTolerance)
}()
go func() {
time.Sleep(30 * time.Millisecond)
atomic.StoreInt32(&tolerance, outsideTolerance)
}()
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
testClock.Sleep(20 * time.Millisecond)
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
testClock.Sleep(20 * time.Millisecond)
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
})
It("uses time.Tick for Tick", func() {
var tolerance int32
go func() {
time.Sleep(10 * time.Millisecond)
atomic.StoreInt32(&tolerance, withinTolerance)
}()
go func() {
time.Sleep(50 * time.Millisecond)
atomic.StoreInt32(&tolerance, outsideTolerance)
}()
ch := testClock.Tick(20 * time.Millisecond)
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
<-ch
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
<-ch
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
<-ch
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
})
It("uses time.Ticker for Ticker", func() {
var tolerance int32
go func() {
time.Sleep(10 * time.Millisecond)
atomic.StoreInt32(&tolerance, withinTolerance)
}()
go func() {
time.Sleep(50 * time.Millisecond)
atomic.StoreInt32(&tolerance, outsideTolerance)
}()
ticker := testClock.Ticker(20 * time.Millisecond)
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
<-ticker.C
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
<-ticker.C
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
<-ticker.C
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
})
It("errors if Add is used", func() {
err := testClock.Add(100 * time.Second)
Expect(err).To(HaveOccurred())
})
})
Context("Clock mocked via Set", func() {
var now = time.Unix(testLocalEpoch, 0)
BeforeEach(func() {
testClock.Set(now)
})
It("mocks After", func() {
var after int32
ready := make(chan struct{})
ch := testClock.After(10 * time.Second)
go func(ch <-chan time.Time) {
close(ready)
<-ch
atomic.StoreInt32(&after, 1)
}(ch)
<-ready
err := testClock.Add(9 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
err = testClock.Add(1 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
})
It("mocks AfterFunc", func() {
var after int32
testClock.AfterFunc(10*time.Second, func() {
atomic.StoreInt32(&after, 1)
})
err := testClock.Add(9 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
err = testClock.Add(1 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
})
It("mocks AfterFunc with a stopped timer", func() {
var after int32
timer := testClock.AfterFunc(10*time.Second, func() {
atomic.StoreInt32(&after, 1)
})
timer.Stop()
err := testClock.Add(11 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
})
It("mocks Now", func() {
Expect(testClock.Now()).To(Equal(now))
err := testClock.Add(123 * time.Hour)
Expect(err).ToNot(HaveOccurred())
Expect(testClock.Now()).To(Equal(now.Add(123 * time.Hour)))
})
It("mocks Since", func() {
Expect(testClock.Since(time.Unix(testLocalEpoch-100, 0))).
To(Equal(100 * time.Second))
})
It("mocks Sleep", func() {
var after int32
ready := make(chan struct{})
go func() {
close(ready)
testClock.Sleep(10 * time.Second)
atomic.StoreInt32(&after, 1)
}()
<-ready
err := testClock.Add(9 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
err = testClock.Add(1 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
})
It("mocks Tick", func() {
var ticks int32
ready := make(chan struct{})
go func() {
close(ready)
tick := testClock.Tick(10 * time.Second)
for ticks < 5 {
<-tick
atomic.AddInt32(&ticks, 1)
}
}()
<-ready
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
err := testClock.Add(9 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
err = testClock.Add(1 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1)))
err = testClock.Add(30 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4)))
err = testClock.Add(10 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5)))
})
It("mocks Ticker", func() {
var ticks int32
ready := make(chan struct{})
go func() {
ticker := testClock.Ticker(10 * time.Second)
close(ready)
for ticks < 5 {
<-ticker.C
atomic.AddInt32(&ticks, 1)
}
}()
<-ready
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
err := testClock.Add(9 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
err = testClock.Add(1 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1)))
err = testClock.Add(30 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4)))
err = testClock.Add(10 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5)))
})
It("mocks Timer", func() {
var after int32
ready := make(chan struct{})
go func() {
timer := testClock.Timer(10 * time.Second)
close(ready)
<-timer.C
atomic.AddInt32(&after, 1)
}()
<-ready
err := testClock.Add(9 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
err = testClock.Add(1 * time.Second)
Expect(err).ToNot(HaveOccurred())
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
})
})
})
Context("Global time overridden", func() {
var (
globalNow = time.Unix(testGlobalEpoch, 0)
localNow = time.Unix(testLocalEpoch, 0)
)
BeforeEach(func() {
clock.Set(globalNow)
})
Context("Clock not mocked via Set", func() {
It("uses globally mocked Now", func() {
Expect(testClock.Now()).To(Equal(globalNow))
err := clock.Add(123 * time.Hour)
Expect(err).ToNot(HaveOccurred())
Expect(testClock.Now()).To(Equal(globalNow.Add(123 * time.Hour)))
})
It("errors when Add is called on the local Clock", func() {
err := testClock.Add(100 * time.Hour)
Expect(err).To(HaveOccurred())
})
})
Context("Clock is mocked via Set", func() {
BeforeEach(func() {
testClock.Set(localNow)
})
It("uses the local mock and ignores the global", func() {
Expect(testClock.Now()).To(Equal(localNow))
err := clock.Add(456 * time.Hour)
Expect(err).ToNot(HaveOccurred())
err = testClock.Add(123 * time.Hour)
Expect(err).ToNot(HaveOccurred())
Expect(testClock.Now()).To(Equal(localNow.Add(123 * time.Hour)))
})
})
})
})

View File

@ -10,7 +10,6 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/vmihailenco/msgpack/v5"
)
@ -47,7 +46,7 @@ type csrf struct {
CodeVerifier string `msgpack:"cv,omitempty"`
cookieOpts *options.Cookie
time clock.Clock
clock func() time.Time
}
// csrtStateTrim will indicate the length of the state trimmed for the name of the csrf cookie
@ -70,6 +69,7 @@ func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) {
CodeVerifier: codeVerifier,
cookieOpts: opts,
clock: time.Now,
}, nil
}
@ -187,7 +187,7 @@ func ClearExtraCsrfCookies(opts *options.Cookie, rw http.ResponseWriter, req *ht
// delete the X oldest cookies
slices.SortStableFunc(decodedCookies, func(a, b *csrf) int {
return a.time.Now().Compare(b.time.Now())
return a.clock().Compare(b.clock())
})
for i := 0; i < len(decodedCookies)-opts.CSRFPerRequestLimit; i++ {
@ -223,7 +223,7 @@ func (c *csrf) encodeCookie() (string, error) {
if err != nil {
return "", fmt.Errorf("error getting cookie secret: %v", err)
}
return encryption.SignedValue(secret, c.cookieName(), encrypted, c.time.Now())
return encryption.SignedValue(secret, c.cookieName(), encrypted, c.clock())
}
// decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
@ -249,10 +249,10 @@ func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error)
// unmarshalCSRF unmarshals decrypted data into a CSRF struct
func unmarshalCSRF(decrypted []byte, opts *options.Cookie, csrfTime time.Time) (*csrf, error) {
clock := clock.Clock{}
clock.Set(csrfTime)
csrf := &csrf{cookieOpts: opts, time: clock}
csrf := &csrf{
cookieOpts: opts,
clock: func() time.Time { return csrfTime },
}
if err := msgpack.Unmarshal(decrypted, csrf); err != nil {
return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err)
}

View File

@ -128,7 +128,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() {
testNow := time.Unix(nowEpoch, 0)
BeforeEach(func() {
privateCSRF.time.Set(testNow)
privateCSRF.clock = func() time.Time { return testNow }
req = &http.Request{
Method: http.MethodGet,
@ -144,7 +144,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() {
})
AfterEach(func() {
privateCSRF.time.Reset()
privateCSRF.clock = time.Now
})
Context("SetCookie", func() {
@ -200,17 +200,17 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() {
publicCSRF1, err := NewCSRF(cookieOpts, "verifier")
Expect(err).ToNot(HaveOccurred())
privateCSRF1 := publicCSRF1.(*csrf)
privateCSRF1.time.Set(testNow)
privateCSRF1.clock = func() time.Time { return testNow }
publicCSRF2, err := NewCSRF(cookieOpts, "verifier")
Expect(err).ToNot(HaveOccurred())
privateCSRF2 := publicCSRF2.(*csrf)
privateCSRF2.time.Set(testNow.Add(time.Minute))
privateCSRF2.clock = func() time.Time { return testNow.Add(time.Minute) }
publicCSRF3, err := NewCSRF(cookieOpts, "verifier")
Expect(err).ToNot(HaveOccurred())
privateCSRF3 := publicCSRF3.(*csrf)
privateCSRF3.time.Set(testNow.Add(time.Minute * 2))
privateCSRF3.clock = func() time.Time { return testNow.Add(time.Minute * 2) }
cookies := []string{}
for _, csrf := range []*csrf{privateCSRF1, privateCSRF2, privateCSRF3} {

View File

@ -130,7 +130,7 @@ var _ = Describe("CSRF Cookie Tests", func() {
testNow := time.Unix(nowEpoch, 0)
BeforeEach(func() {
privateCSRF.time.Set(testNow)
privateCSRF.clock = func() time.Time { return testNow }
req = &http.Request{
Method: http.MethodGet,
@ -146,7 +146,7 @@ var _ = Describe("CSRF Cookie Tests", func() {
})
AfterEach(func() {
privateCSRF.time.Reset()
privateCSRF.clock = time.Now
})
Context("SetCookie", func() {
@ -173,7 +173,7 @@ var _ = Describe("CSRF Cookie Tests", func() {
Context("LoadCSRFCookie", func() {
BeforeEach(func() {
// we need to reset the time to ensure the cookie is valid
privateCSRF.time.Reset()
privateCSRF.clock = time.Now
})
It("should return error when no cookie is set", func() {

View File

@ -11,7 +11,6 @@ import (
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -95,6 +94,7 @@ var _ = Describe("Stored Session Suite", func() {
now := time.Now()
createdPast := now.Add(-5 * time.Minute)
createdFuture := now.Add(5 * time.Minute)
clock := func() time.Time { return now }
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken {
@ -120,6 +120,7 @@ var _ = Describe("Stored Session Suite", func() {
RefreshToken: noRefresh,
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Clock: clock,
}, nil
case "_oauth2_proxy=InvalidNoRefreshSession":
return &sessionsapi.SessionState{
@ -127,24 +128,28 @@ var _ = Describe("Stored Session Suite", func() {
RefreshToken: noRefresh,
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Clock: clock,
}, nil
case "_oauth2_proxy=ExpiredNoRefreshSession":
return &sessionsapi.SessionState{
RefreshToken: noRefresh,
CreatedAt: &createdPast,
ExpiresOn: &createdPast,
Clock: clock,
}, nil
case "_oauth2_proxy=RefreshSession":
return &sessionsapi.SessionState{
RefreshToken: refresh,
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Clock: clock,
}, nil
case "_oauth2_proxy=RefreshError":
return &sessionsapi.SessionState{
RefreshToken: "RefreshError",
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Clock: clock,
}, nil
case "_oauth2_proxy=NonExistent":
return nil, fmt.Errorf("invalid cookie")
@ -154,14 +159,6 @@ var _ = Describe("Stored Session Suite", func() {
},
}
BeforeEach(func() {
clock.Set(now)
})
AfterEach(func() {
clock.Reset()
})
type storedSessionLoaderTableInput struct {
requestHeaders http.Header
existingSession *sessionsapi.SessionState
@ -200,7 +197,15 @@ var _ = Describe("Stored Session Suite", func() {
}))
handler.ServeHTTP(rw, req)
Expect(gotSession).To(Equal(in.expectedSession))
// Compare, ignoring testing Clock.
if in.expectedSession == nil {
Expect(gotSession).To(BeNil())
return
}
Expect(gotSession).ToNot(BeNil())
got := *gotSession
got.Clock = nil
Expect(&got).To(Equal(in.expectedSession))
},
Entry("with no cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{},

View File

@ -17,8 +17,9 @@ func TestRefresh(t *testing.T) {
now := time.Unix(1234567890, 10)
expires := time.Unix(1234567890, 0)
ss := &sessions.SessionState{}
ss.Clock.Set(now)
ss := &sessions.SessionState{
Clock: func() time.Time { return now },
}
ss.SetExpiresOn(expires)
refreshed, err := p.RefreshSession(context.Background(), ss)