test: replace mock pkg/clock with narrowly targeted stub clocks. (#3238)
The package under pkg/clock is github.com/benbjohnson/clock, which is archived. It's also way more complex than is what is actually needed here, so we can entirely remove the dependency and remove the helper package. Fixes #2840. Signed-off-by: David Symonds <dsymonds@gmail.com>
This commit is contained in:
parent
8f687e4d0c
commit
110d51d1d7
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
- [#3228](https://github.com/oauth2-proxy/oauth2-proxy/pull/3228) fix: use GetSecret() in ticket.go makeCookie to respect cookie-secret-file (@stagswtf)
|
||||
- [#3244](https://github.com/oauth2-proxy/oauth2-proxy/pull/3244) chore(deps): upgrade to latest go1.25.3 (@tuunit)
|
||||
- [#3238](https://github.com/oauth2-proxy/oauth2-proxy/pull/3238) chore: Replace pkg/clock with narrowly targeted stub clocks (@dsymonds)
|
||||
|
||||
# V7.12.0
|
||||
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -7,7 +7,6 @@ require (
|
|||
github.com/Bose/minisentinel v0.0.0-20200130220412-917c5a9223bb
|
||||
github.com/a8m/envsubst v1.4.3
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/benbjohnson/clock v1.3.5
|
||||
github.com/bitly/go-simplejson v0.5.1
|
||||
github.com/bsm/redislock v0.9.4
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -14,8 +14,6 @@ github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGn
|
|||
github.com/alicebob/miniredis/v2 v2.11.1/go.mod h1:UA48pmi7aSazcGAvcdKcBB49z521IC9VjTTRz2nIaJE=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o=
|
||||
github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow=
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||
"github.com/pierrec/lz4/v4"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
|
|
@ -30,10 +29,17 @@ type SessionState struct {
|
|||
PreferredUsername string `msgpack:"pu,omitempty"`
|
||||
|
||||
// Internal helpers, not serialized
|
||||
Clock clock.Clock `msgpack:"-"`
|
||||
Clock func() time.Time `msgpack:"-"` // override for time.Now, for testing
|
||||
Lock Lock `msgpack:"-"`
|
||||
}
|
||||
|
||||
func (s *SessionState) now() time.Time {
|
||||
if s.Clock != nil {
|
||||
return s.Clock()
|
||||
}
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error {
|
||||
if s.Lock == nil {
|
||||
s.Lock = &NoOpLock{}
|
||||
|
|
@ -64,7 +70,7 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) {
|
|||
|
||||
// CreatedAtNow sets a SessionState's CreatedAt to now
|
||||
func (s *SessionState) CreatedAtNow() {
|
||||
now := s.Clock.Now()
|
||||
now := s.now()
|
||||
s.CreatedAt = &now
|
||||
}
|
||||
|
||||
|
|
@ -85,7 +91,7 @@ func (s *SessionState) ExpiresIn(d time.Duration) {
|
|||
|
||||
// IsExpired checks whether the session has expired
|
||||
func (s *SessionState) IsExpired() bool {
|
||||
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) {
|
||||
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.now()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
|
@ -94,7 +100,7 @@ func (s *SessionState) IsExpired() bool {
|
|||
// Age returns the age of a session
|
||||
func (s *SessionState) Age() time.Duration {
|
||||
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
|
||||
return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt)
|
||||
return s.now().Truncate(time.Second).Sub(*s.CreatedAt)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ func TestCreatedAtNow(t *testing.T) {
|
|||
ss := &SessionState{}
|
||||
|
||||
now := time.Unix(1234567890, 0)
|
||||
ss.Clock.Set(now)
|
||||
ss.Clock = func() time.Time { return now }
|
||||
|
||||
ss.CreatedAtNow()
|
||||
g.Expect(*ss.CreatedAt).To(Equal(now))
|
||||
|
|
@ -33,9 +33,9 @@ func TestExpiresIn(t *testing.T) {
|
|||
ss := &SessionState{}
|
||||
|
||||
now := time.Unix(1234567890, 0)
|
||||
ss.Clock.Set(now)
|
||||
ss.Clock = func() time.Time { return now }
|
||||
|
||||
ttl := time.Duration(743) * time.Second
|
||||
ttl := 743 * time.Second
|
||||
ss.ExpiresIn(ttl)
|
||||
|
||||
g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl)))
|
||||
|
|
|
|||
|
|
@ -1,157 +0,0 @@
|
|||
package clock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
clockapi "github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
var (
|
||||
globalClock = clockapi.New()
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
// Set the global clock to a clockapi.Mock with the given time.Time
|
||||
func Set(t time.Time) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
mock, ok := globalClock.(*clockapi.Mock)
|
||||
if !ok {
|
||||
mock = clockapi.NewMock()
|
||||
}
|
||||
mock.Set(t)
|
||||
globalClock = mock
|
||||
}
|
||||
|
||||
// Add moves the mocked global clock forward the given duration. It will error
|
||||
// if the global clock is not mocked.
|
||||
func Add(d time.Duration) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
mock, ok := globalClock.(*clockapi.Mock)
|
||||
if !ok {
|
||||
return errors.New("time not mocked")
|
||||
}
|
||||
mock.Add(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset sets the global clock to a pure time implementation. Returns any
|
||||
// existing Mock if set in case lingering time operations are attached to it.
|
||||
func Reset() *clockapi.Mock {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
existing := globalClock
|
||||
globalClock = clockapi.New()
|
||||
|
||||
mock, ok := existing.(*clockapi.Mock)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return mock
|
||||
}
|
||||
|
||||
// Clock is a non-package level wrapper around time that supports stubbing.
|
||||
// It will use its localized stubs (allowing for parallelized unit tests
|
||||
// where package level stubbing would cause issues). It falls back to any
|
||||
// package level time stubs for non-parallel, cross-package integration
|
||||
// testing scenarios.
|
||||
//
|
||||
// If nothing is stubbed, it defaults to default time behavior in the time
|
||||
// package.
|
||||
type Clock struct {
|
||||
mock *clockapi.Mock
|
||||
}
|
||||
|
||||
// Set sets the Clock to a clock.Mock at the given time.Time
|
||||
func (c *Clock) Set(t time.Time) {
|
||||
if c.mock == nil {
|
||||
c.mock = clockapi.NewMock()
|
||||
}
|
||||
c.mock.Set(t)
|
||||
}
|
||||
|
||||
// Add moves clock forward time.Duration if it is mocked. It will error
|
||||
// if the clock is not mocked.
|
||||
func (c *Clock) Add(d time.Duration) error {
|
||||
if c.mock == nil {
|
||||
return errors.New("clock not mocked")
|
||||
}
|
||||
c.mock.Add(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset removes local clock.Mock. Returns any existing Mock if set in case
|
||||
// lingering time operations are attached to it.
|
||||
func (c *Clock) Reset() *clockapi.Mock {
|
||||
existing := c.mock
|
||||
c.mock = nil
|
||||
return existing
|
||||
}
|
||||
|
||||
func (c *Clock) After(d time.Duration) <-chan time.Time {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
return globalClock.After(d)
|
||||
}
|
||||
return m.After(d)
|
||||
}
|
||||
|
||||
func (c *Clock) AfterFunc(d time.Duration, f func()) *clockapi.Timer {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
return globalClock.AfterFunc(d, f)
|
||||
}
|
||||
return m.AfterFunc(d, f)
|
||||
}
|
||||
|
||||
func (c *Clock) Now() time.Time {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
return globalClock.Now()
|
||||
}
|
||||
return m.Now()
|
||||
}
|
||||
|
||||
func (c *Clock) Since(t time.Time) time.Duration {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
return globalClock.Since(t)
|
||||
}
|
||||
return m.Since(t)
|
||||
}
|
||||
|
||||
func (c *Clock) Sleep(d time.Duration) {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
globalClock.Sleep(d)
|
||||
return
|
||||
}
|
||||
m.Sleep(d)
|
||||
}
|
||||
|
||||
func (c *Clock) Tick(d time.Duration) <-chan time.Time {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
return globalClock.Tick(d)
|
||||
}
|
||||
return m.Tick(d)
|
||||
}
|
||||
|
||||
func (c *Clock) Ticker(d time.Duration) *clockapi.Ticker {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
return globalClock.Ticker(d)
|
||||
}
|
||||
return m.Ticker(d)
|
||||
}
|
||||
|
||||
func (c *Clock) Timer(d time.Duration) *clockapi.Timer {
|
||||
m := c.mock
|
||||
if m == nil {
|
||||
return globalClock.Timer(d)
|
||||
}
|
||||
return m.Timer(d)
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
package clock_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestClockSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
logger.SetErrOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Clock")
|
||||
}
|
||||
|
|
@ -1,380 +0,0 @@
|
|||
package clock_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
testGlobalEpoch = 1000000000
|
||||
testLocalEpoch = 1234567890
|
||||
)
|
||||
|
||||
var _ = Describe("Clock suite", func() {
|
||||
var testClock = clock.Clock{}
|
||||
|
||||
AfterEach(func() {
|
||||
clock.Reset()
|
||||
testClock.Reset()
|
||||
})
|
||||
|
||||
Context("Global time not overridden", func() {
|
||||
It("errors when trying to Add", func() {
|
||||
err := clock.Add(123 * time.Hour)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("Clock not mocked via Set", func() {
|
||||
const (
|
||||
outsideTolerance = int32(0)
|
||||
withinTolerance = int32(1)
|
||||
)
|
||||
|
||||
It("uses time.After for After", func() {
|
||||
var tolerance int32
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, withinTolerance)
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, outsideTolerance)
|
||||
}()
|
||||
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
|
||||
<-testClock.After(20 * time.Millisecond)
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
|
||||
|
||||
<-testClock.After(20 * time.Millisecond)
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
})
|
||||
|
||||
It("uses time.AfterFunc for AfterFunc", func() {
|
||||
var tolerance int32
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, withinTolerance)
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, outsideTolerance)
|
||||
}()
|
||||
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
testClock.AfterFunc(20*time.Millisecond, func() {
|
||||
wg.Done()
|
||||
})
|
||||
wg.Wait()
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
|
||||
|
||||
wg.Add(1)
|
||||
testClock.AfterFunc(20*time.Millisecond, func() {
|
||||
wg.Done()
|
||||
})
|
||||
wg.Wait()
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
})
|
||||
|
||||
It("uses time.Now for Now", func() {
|
||||
a := time.Now()
|
||||
b := testClock.Now()
|
||||
Expect(b.Sub(a).Round(10 * time.Millisecond)).To(Equal(0 * time.Millisecond))
|
||||
})
|
||||
|
||||
It("uses time.Since for Since", func() {
|
||||
past := time.Now().Add(-60 * time.Second)
|
||||
Expect(time.Since(past).Round(10 * time.Millisecond)).
|
||||
To(Equal(60 * time.Second))
|
||||
})
|
||||
|
||||
It("uses time.Sleep for Sleep", func() {
|
||||
var tolerance int32
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, withinTolerance)
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, outsideTolerance)
|
||||
}()
|
||||
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
|
||||
testClock.Sleep(20 * time.Millisecond)
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
|
||||
|
||||
testClock.Sleep(20 * time.Millisecond)
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
})
|
||||
|
||||
It("uses time.Tick for Tick", func() {
|
||||
var tolerance int32
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, withinTolerance)
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, outsideTolerance)
|
||||
}()
|
||||
|
||||
ch := testClock.Tick(20 * time.Millisecond)
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
<-ch
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
|
||||
<-ch
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
|
||||
<-ch
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
})
|
||||
|
||||
It("uses time.Ticker for Ticker", func() {
|
||||
var tolerance int32
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, withinTolerance)
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
atomic.StoreInt32(&tolerance, outsideTolerance)
|
||||
}()
|
||||
|
||||
ticker := testClock.Ticker(20 * time.Millisecond)
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
<-ticker.C
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
|
||||
<-ticker.C
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(withinTolerance))
|
||||
<-ticker.C
|
||||
Expect(atomic.LoadInt32(&tolerance)).To(Equal(outsideTolerance))
|
||||
})
|
||||
|
||||
It("errors if Add is used", func() {
|
||||
err := testClock.Add(100 * time.Second)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Clock mocked via Set", func() {
|
||||
var now = time.Unix(testLocalEpoch, 0)
|
||||
|
||||
BeforeEach(func() {
|
||||
testClock.Set(now)
|
||||
})
|
||||
|
||||
It("mocks After", func() {
|
||||
var after int32
|
||||
ready := make(chan struct{})
|
||||
ch := testClock.After(10 * time.Second)
|
||||
go func(ch <-chan time.Time) {
|
||||
close(ready)
|
||||
<-ch
|
||||
atomic.StoreInt32(&after, 1)
|
||||
}(ch)
|
||||
<-ready
|
||||
|
||||
err := testClock.Add(9 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
|
||||
|
||||
err = testClock.Add(1 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("mocks AfterFunc", func() {
|
||||
var after int32
|
||||
testClock.AfterFunc(10*time.Second, func() {
|
||||
atomic.StoreInt32(&after, 1)
|
||||
})
|
||||
|
||||
err := testClock.Add(9 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
|
||||
|
||||
err = testClock.Add(1 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("mocks AfterFunc with a stopped timer", func() {
|
||||
var after int32
|
||||
timer := testClock.AfterFunc(10*time.Second, func() {
|
||||
atomic.StoreInt32(&after, 1)
|
||||
})
|
||||
timer.Stop()
|
||||
|
||||
err := testClock.Add(11 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
|
||||
})
|
||||
|
||||
It("mocks Now", func() {
|
||||
Expect(testClock.Now()).To(Equal(now))
|
||||
err := testClock.Add(123 * time.Hour)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(testClock.Now()).To(Equal(now.Add(123 * time.Hour)))
|
||||
})
|
||||
|
||||
It("mocks Since", func() {
|
||||
Expect(testClock.Since(time.Unix(testLocalEpoch-100, 0))).
|
||||
To(Equal(100 * time.Second))
|
||||
})
|
||||
|
||||
It("mocks Sleep", func() {
|
||||
var after int32
|
||||
ready := make(chan struct{})
|
||||
go func() {
|
||||
close(ready)
|
||||
testClock.Sleep(10 * time.Second)
|
||||
atomic.StoreInt32(&after, 1)
|
||||
}()
|
||||
<-ready
|
||||
|
||||
err := testClock.Add(9 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
|
||||
|
||||
err = testClock.Add(1 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
|
||||
})
|
||||
|
||||
It("mocks Tick", func() {
|
||||
var ticks int32
|
||||
ready := make(chan struct{})
|
||||
go func() {
|
||||
close(ready)
|
||||
tick := testClock.Tick(10 * time.Second)
|
||||
for ticks < 5 {
|
||||
<-tick
|
||||
atomic.AddInt32(&ticks, 1)
|
||||
}
|
||||
}()
|
||||
<-ready
|
||||
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
|
||||
|
||||
err := testClock.Add(9 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
|
||||
|
||||
err = testClock.Add(1 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1)))
|
||||
|
||||
err = testClock.Add(30 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4)))
|
||||
|
||||
err = testClock.Add(10 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5)))
|
||||
})
|
||||
|
||||
It("mocks Ticker", func() {
|
||||
var ticks int32
|
||||
ready := make(chan struct{})
|
||||
go func() {
|
||||
ticker := testClock.Ticker(10 * time.Second)
|
||||
close(ready)
|
||||
for ticks < 5 {
|
||||
<-ticker.C
|
||||
atomic.AddInt32(&ticks, 1)
|
||||
}
|
||||
}()
|
||||
<-ready
|
||||
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
|
||||
|
||||
err := testClock.Add(9 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(0)))
|
||||
|
||||
err = testClock.Add(1 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(1)))
|
||||
|
||||
err = testClock.Add(30 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(4)))
|
||||
|
||||
err = testClock.Add(10 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&ticks)).To(Equal(int32(5)))
|
||||
})
|
||||
|
||||
It("mocks Timer", func() {
|
||||
var after int32
|
||||
ready := make(chan struct{})
|
||||
go func() {
|
||||
timer := testClock.Timer(10 * time.Second)
|
||||
close(ready)
|
||||
<-timer.C
|
||||
atomic.AddInt32(&after, 1)
|
||||
}()
|
||||
<-ready
|
||||
|
||||
err := testClock.Add(9 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(0)))
|
||||
|
||||
err = testClock.Add(1 * time.Second)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(atomic.LoadInt32(&after)).To(Equal(int32(1)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Global time overridden", func() {
|
||||
var (
|
||||
globalNow = time.Unix(testGlobalEpoch, 0)
|
||||
localNow = time.Unix(testLocalEpoch, 0)
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
clock.Set(globalNow)
|
||||
})
|
||||
|
||||
Context("Clock not mocked via Set", func() {
|
||||
It("uses globally mocked Now", func() {
|
||||
Expect(testClock.Now()).To(Equal(globalNow))
|
||||
err := clock.Add(123 * time.Hour)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(testClock.Now()).To(Equal(globalNow.Add(123 * time.Hour)))
|
||||
})
|
||||
|
||||
It("errors when Add is called on the local Clock", func() {
|
||||
err := testClock.Add(100 * time.Hour)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Clock is mocked via Set", func() {
|
||||
BeforeEach(func() {
|
||||
testClock.Set(localNow)
|
||||
})
|
||||
|
||||
It("uses the local mock and ignores the global", func() {
|
||||
Expect(testClock.Now()).To(Equal(localNow))
|
||||
|
||||
err := clock.Add(456 * time.Hour)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = testClock.Add(123 * time.Hour)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(testClock.Now()).To(Equal(localNow.Add(123 * time.Hour)))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -10,7 +10,6 @@ import (
|
|||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
|
@ -47,7 +46,7 @@ type csrf struct {
|
|||
CodeVerifier string `msgpack:"cv,omitempty"`
|
||||
|
||||
cookieOpts *options.Cookie
|
||||
time clock.Clock
|
||||
clock func() time.Time
|
||||
}
|
||||
|
||||
// csrtStateTrim will indicate the length of the state trimmed for the name of the csrf cookie
|
||||
|
|
@ -70,6 +69,7 @@ func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) {
|
|||
CodeVerifier: codeVerifier,
|
||||
|
||||
cookieOpts: opts,
|
||||
clock: time.Now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -187,7 +187,7 @@ func ClearExtraCsrfCookies(opts *options.Cookie, rw http.ResponseWriter, req *ht
|
|||
|
||||
// delete the X oldest cookies
|
||||
slices.SortStableFunc(decodedCookies, func(a, b *csrf) int {
|
||||
return a.time.Now().Compare(b.time.Now())
|
||||
return a.clock().Compare(b.clock())
|
||||
})
|
||||
|
||||
for i := 0; i < len(decodedCookies)-opts.CSRFPerRequestLimit; i++ {
|
||||
|
|
@ -223,7 +223,7 @@ func (c *csrf) encodeCookie() (string, error) {
|
|||
if err != nil {
|
||||
return "", fmt.Errorf("error getting cookie secret: %v", err)
|
||||
}
|
||||
return encryption.SignedValue(secret, c.cookieName(), encrypted, c.time.Now())
|
||||
return encryption.SignedValue(secret, c.cookieName(), encrypted, c.clock())
|
||||
}
|
||||
|
||||
// decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
|
||||
|
|
@ -249,10 +249,10 @@ func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error)
|
|||
|
||||
// unmarshalCSRF unmarshals decrypted data into a CSRF struct
|
||||
func unmarshalCSRF(decrypted []byte, opts *options.Cookie, csrfTime time.Time) (*csrf, error) {
|
||||
clock := clock.Clock{}
|
||||
clock.Set(csrfTime)
|
||||
|
||||
csrf := &csrf{cookieOpts: opts, time: clock}
|
||||
csrf := &csrf{
|
||||
cookieOpts: opts,
|
||||
clock: func() time.Time { return csrfTime },
|
||||
}
|
||||
if err := msgpack.Unmarshal(decrypted, csrf); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() {
|
|||
testNow := time.Unix(nowEpoch, 0)
|
||||
|
||||
BeforeEach(func() {
|
||||
privateCSRF.time.Set(testNow)
|
||||
privateCSRF.clock = func() time.Time { return testNow }
|
||||
|
||||
req = &http.Request{
|
||||
Method: http.MethodGet,
|
||||
|
|
@ -144,7 +144,7 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() {
|
|||
})
|
||||
|
||||
AfterEach(func() {
|
||||
privateCSRF.time.Reset()
|
||||
privateCSRF.clock = time.Now
|
||||
})
|
||||
|
||||
Context("SetCookie", func() {
|
||||
|
|
@ -200,17 +200,17 @@ var _ = Describe("CSRF Cookie with non-fixed name Tests", func() {
|
|||
publicCSRF1, err := NewCSRF(cookieOpts, "verifier")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
privateCSRF1 := publicCSRF1.(*csrf)
|
||||
privateCSRF1.time.Set(testNow)
|
||||
privateCSRF1.clock = func() time.Time { return testNow }
|
||||
|
||||
publicCSRF2, err := NewCSRF(cookieOpts, "verifier")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
privateCSRF2 := publicCSRF2.(*csrf)
|
||||
privateCSRF2.time.Set(testNow.Add(time.Minute))
|
||||
privateCSRF2.clock = func() time.Time { return testNow.Add(time.Minute) }
|
||||
|
||||
publicCSRF3, err := NewCSRF(cookieOpts, "verifier")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
privateCSRF3 := publicCSRF3.(*csrf)
|
||||
privateCSRF3.time.Set(testNow.Add(time.Minute * 2))
|
||||
privateCSRF3.clock = func() time.Time { return testNow.Add(time.Minute * 2) }
|
||||
|
||||
cookies := []string{}
|
||||
for _, csrf := range []*csrf{privateCSRF1, privateCSRF2, privateCSRF3} {
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ var _ = Describe("CSRF Cookie Tests", func() {
|
|||
testNow := time.Unix(nowEpoch, 0)
|
||||
|
||||
BeforeEach(func() {
|
||||
privateCSRF.time.Set(testNow)
|
||||
privateCSRF.clock = func() time.Time { return testNow }
|
||||
|
||||
req = &http.Request{
|
||||
Method: http.MethodGet,
|
||||
|
|
@ -146,7 +146,7 @@ var _ = Describe("CSRF Cookie Tests", func() {
|
|||
})
|
||||
|
||||
AfterEach(func() {
|
||||
privateCSRF.time.Reset()
|
||||
privateCSRF.clock = time.Now
|
||||
})
|
||||
|
||||
Context("SetCookie", func() {
|
||||
|
|
@ -173,7 +173,7 @@ var _ = Describe("CSRF Cookie Tests", func() {
|
|||
Context("LoadCSRFCookie", func() {
|
||||
BeforeEach(func() {
|
||||
// we need to reset the time to ensure the cookie is valid
|
||||
privateCSRF.time.Reset()
|
||||
privateCSRF.clock = time.Now
|
||||
})
|
||||
|
||||
It("should return error when no cookie is set", func() {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
|
@ -95,6 +94,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
now := time.Now()
|
||||
createdPast := now.Add(-5 * time.Minute)
|
||||
createdFuture := now.Add(5 * time.Minute)
|
||||
clock := func() time.Time { return now }
|
||||
|
||||
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
|
||||
switch ss.RefreshToken {
|
||||
|
|
@ -120,6 +120,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
RefreshToken: noRefresh,
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
Clock: clock,
|
||||
}, nil
|
||||
case "_oauth2_proxy=InvalidNoRefreshSession":
|
||||
return &sessionsapi.SessionState{
|
||||
|
|
@ -127,24 +128,28 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
RefreshToken: noRefresh,
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
Clock: clock,
|
||||
}, nil
|
||||
case "_oauth2_proxy=ExpiredNoRefreshSession":
|
||||
return &sessionsapi.SessionState{
|
||||
RefreshToken: noRefresh,
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdPast,
|
||||
Clock: clock,
|
||||
}, nil
|
||||
case "_oauth2_proxy=RefreshSession":
|
||||
return &sessionsapi.SessionState{
|
||||
RefreshToken: refresh,
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
Clock: clock,
|
||||
}, nil
|
||||
case "_oauth2_proxy=RefreshError":
|
||||
return &sessionsapi.SessionState{
|
||||
RefreshToken: "RefreshError",
|
||||
CreatedAt: &createdPast,
|
||||
ExpiresOn: &createdFuture,
|
||||
Clock: clock,
|
||||
}, nil
|
||||
case "_oauth2_proxy=NonExistent":
|
||||
return nil, fmt.Errorf("invalid cookie")
|
||||
|
|
@ -154,14 +159,6 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
},
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
clock.Set(now)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
clock.Reset()
|
||||
})
|
||||
|
||||
type storedSessionLoaderTableInput struct {
|
||||
requestHeaders http.Header
|
||||
existingSession *sessionsapi.SessionState
|
||||
|
|
@ -200,7 +197,15 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
Expect(gotSession).To(Equal(in.expectedSession))
|
||||
// Compare, ignoring testing Clock.
|
||||
if in.expectedSession == nil {
|
||||
Expect(gotSession).To(BeNil())
|
||||
return
|
||||
}
|
||||
Expect(gotSession).ToNot(BeNil())
|
||||
got := *gotSession
|
||||
got.Clock = nil
|
||||
Expect(&got).To(Equal(in.expectedSession))
|
||||
},
|
||||
Entry("with no cookie", storedSessionLoaderTableInput{
|
||||
requestHeaders: http.Header{},
|
||||
|
|
|
|||
|
|
@ -17,8 +17,9 @@ func TestRefresh(t *testing.T) {
|
|||
now := time.Unix(1234567890, 10)
|
||||
expires := time.Unix(1234567890, 0)
|
||||
|
||||
ss := &sessions.SessionState{}
|
||||
ss.Clock.Set(now)
|
||||
ss := &sessions.SessionState{
|
||||
Clock: func() time.Time { return now },
|
||||
}
|
||||
ss.SetExpiresOn(expires)
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), ss)
|
||||
|
|
|
|||
Loading…
Reference in New Issue