feat: refresh token on demand

Co-Authored-By: Chadwick Baatz <40895835+cmbaatz@users.noreply.github.com>
Co-Authored-By: Jan Larwig <jan@larwig.com>
Signed-off-by: Jan Larwig <jan@larwig.com>
This commit is contained in:
Jan Larwig 2025-07-17 09:44:45 +02:00
parent 0e1dc9bb84
commit b0ae034d37
No known key found for this signature in database
GPG Key ID: C2172BFA220A037A
5 changed files with 296 additions and 9 deletions

View File

@ -12,6 +12,7 @@
- [#3116](https://github.com/oauth2-proxy/oauth2-proxy/pull/3116) feat: bump to go1.24.5 and full dependency update (@wardviaene / @dolmen)
- [#3097](https://github.com/oauth2-proxy/oauth2-proxy/pull/3097) chore(deps): update alpine base image to v3.22.0
- [#3101](https://github.com/oauth2-proxy/oauth2-proxy/pull/3101) fix: return error for empty Redis URL list (@dgivens)
- [#2019](https://github.com/oauth2-proxy/oauth2-proxy/issues/2019) feat: refresh token on demand (@cmbaatz)
# V7.9.0

View File

@ -53,6 +53,7 @@ const (
oauthCallbackPath = "/callback"
authOnlyPath = "/auth"
userInfoPath = "/userinfo"
refreshPath = "/refresh"
staticPathPrefix = "/static/"
)
@ -104,6 +105,7 @@ type OAuthProxy struct {
trustedIPs *ip.NetSet
sessionChain alice.Chain
refreshChain alice.Chain
headersChain alice.Chain
preAuthChain alice.Chain
pageWriter pagewriter.Writer
@ -204,7 +206,21 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
}
sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator)
sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator).Append(
middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
}))
refreshChain := buildSessionChain(opts, provider, basicAuthValidator).Append(
middleware.NewStoredSessionRefresher(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
}))
headersChain, err := buildHeadersChain(opts)
if err != nil {
return nil, fmt.Errorf("could not build headers chain: %v", err)
@ -241,6 +257,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups,
sessionChain: sessionChain,
refreshChain: refreshChain,
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
@ -347,6 +364,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {
// The userinfo and logout endpoints needs to load sessions before handling the request
s.Path(userInfoPath).Handler(p.sessionChain.ThenFunc(p.UserInfo))
s.Path(refreshPath).Handler(p.refreshChain.ThenFunc(p.Proxy))
s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(p.SignOut))
}

View File

@ -2006,6 +2006,10 @@ func Test_noCacheHeaders(t *testing.T) {
path: "/oauth2/userinfo",
hasNoCache: true,
},
{
path: "/oauth2/refresh",
hasNoCache: true,
},
{
path: "/upstream",
hasNoCache: false,

View File

@ -64,6 +64,20 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
return ss.loadSession
}
// NewStoredSessionRefresher creates a new storedSessionLoader which allows for the
// refresh of sessions from the session store on demand.
// If no session is found, the request will be passed to the nex handler.
// If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionRefresher(opts *StoredSessionLoaderOptions) alice.Constructor {
sr := &storedSessionLoader{
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
}
return sr.forceRefreshSession
}
// storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store.
type storedSessionLoader struct {
@ -113,7 +127,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
return nil, err
}
err = s.refreshSessionIfNeeded(rw, req, session)
err = s.refreshSessionIfNeeded(rw, req, session, false)
if err != nil {
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
}
@ -124,8 +138,8 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
// refreshSessionIfNeeded will attempt to refresh a session if the session
// is older than the refresh period.
// Success or fail, we will then validate the session.
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
if !needsRefresh(s.refreshPeriod, session) {
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState, forceRefresh bool) error {
if !forceRefresh && !needsRefresh(s.refreshPeriod, session) {
// Refresh is disabled or the session is not old enough, do nothing
return nil
}
@ -179,7 +193,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
// Loading from the session store creates a new lock in the session.
session.Lock = lock
if !needsRefresh(s.refreshPeriod, session) {
if !forceRefresh && !needsRefresh(s.refreshPeriod, session) {
// The session must have already been refreshed while we were waiting to
// obtain the lock.
return nil
@ -252,3 +266,58 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
return nil
}
// forceRefreshSession attempts to load a session as identified by the request cookies.
// If no session is found, the request will be passed to the next handler.
// If a session was loaded by a previous handler, it will not be refreshed.
func (s *storedSessionLoader) forceRefreshSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := middlewareapi.GetRequestScope(req)
// If refreshPeriod isn't defined then forward request to the next handler
if s.refreshPeriod <= time.Duration(0) {
next.ServeHTTP(rw, req)
return
}
session, err := s.store.Load(req)
if err != nil {
s.handleRefreshError(rw, req, next, err)
return
}
if session == nil {
next.ServeHTTP(rw, req)
return
}
err = s.refreshSessionIfNeeded(rw, req, session, true)
if err != nil {
s.handleRefreshError(rw, req, next, err)
return
}
// Add the session to the scope if it was found
scope.Session = session
next.ServeHTTP(rw, req)
})
}
func (s *storedSessionLoader) handleRefreshError(rw http.ResponseWriter, req *http.Request, next http.Handler, err error) {
if errors.Is(err, http.ErrNoCookie) {
logger.Error("Error refreshing session: session cookie not present")
} else {
// In the case when there was an error loading the session,
// we should clear the session
logger.Errorf("Error loading cookied session: %v, removing session", err)
if err = s.store.Clear(rw, req); err != nil {
logger.Errorf("Error removing session: %v", err)
}
}
scope := middlewareapi.GetRequestScope(req)
scope.Session = nil
next.ServeHTTP(rw, req)
}

View File

@ -85,6 +85,7 @@ var _ = Describe("Stored Session Suite", func() {
const (
refresh = "Refresh"
refreshed = "Refreshed"
forcedRefresh = "Forced-Refresh"
noRefresh = "NoRefresh"
notImplemented = "NotImplemented"
)
@ -94,13 +95,18 @@ var _ = Describe("Stored Session Suite", func() {
Context("StoredSessionLoader", func() {
now := time.Now()
createdPast := now.Add(-5 * time.Minute)
recentPast := now.Add(-5 * time.Second)
createdFuture := now.Add(5 * time.Minute)
recent := now.Add(1 * time.Minute)
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken {
case refresh:
ss.RefreshToken = refreshed
return true, nil
case forcedRefresh:
ss.RefreshToken = refreshed
return true, nil
case noRefresh:
return false, nil
default:
@ -146,6 +152,18 @@ var _ = Describe("Stored Session Suite", func() {
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
}, nil
case "_oauth2_proxy=NewSession":
return &sessionsapi.SessionState{
RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten
CreatedAt: &recentPast,
ExpiresOn: &createdFuture,
}, nil
case "_oauth2_proxy=OldSession":
return &sessionsapi.SessionState{
RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
}, nil
case "_oauth2_proxy=NonExistent":
return nil, fmt.Errorf("invalid cookie")
default:
@ -155,7 +173,7 @@ var _ = Describe("Stored Session Suite", func() {
}
BeforeEach(func() {
clock.Set(now)
clock.Set(recent)
})
AfterEach(func() {
@ -172,7 +190,7 @@ var _ = Describe("Stored Session Suite", func() {
validateSession func(context.Context, *sessionsapi.SessionState) bool
}
DescribeTable("when serving a request",
DescribeTable("when serving a loadSession request",
func(in storedSessionLoaderTableInput) {
scope := &middlewareapi.RequestScope{
Session: in.existingSession,
@ -286,7 +304,7 @@ var _ = Describe("Stored Session Suite", func() {
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &now,
CreatedAt: &recent,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
@ -335,6 +353,183 @@ var _ = Describe("Stored Session Suite", func() {
}),
)
DescribeTable("when service a forceRefresh request",
func(in storedSessionLoaderTableInput) {
scope := &middlewareapi.RequestScope{
Session: in.existingSession,
}
// Set up the request with the request header and a request scope
req := httptest.NewRequest("", "/", nil)
req.Header = in.requestHeaders
req = middlewareapi.AddRequestScope(req, scope)
rw := httptest.NewRecorder()
opts := &StoredSessionLoaderOptions{
SessionStore: in.store,
RefreshPeriod: in.refreshPeriod,
RefreshSession: in.refreshSession,
ValidateSession: in.validateSession,
}
// Create the handler with a next handler that will capture the session
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewStoredSessionRefresher(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)
Expect(gotSession).To(Equal(in.expectedSession))
},
Entry("with no cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with an invalid cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NonExistent"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with an existing session", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NewSession"},
},
existingSession: &sessionsapi.SessionState{
RefreshToken: "Forced-Refresh",
CreatedAt: &recentPast,
},
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &recent,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that has not expired and cannot be refreshed", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NoRefreshSession"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: noRefresh,
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NewSession"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &recent,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=OldSession"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &recent,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "RefreshError",
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false },
}),
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when refresh period is not defined", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NewSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 0,
}),
)
type storedSessionLoaderConcurrentTableInput struct {
existingSession *sessionsapi.SessionState
refreshPeriod time.Duration
@ -496,7 +691,7 @@ var _ = Describe("Stored Session Suite", func() {
}
req := httptest.NewRequest("", "/", nil)
err := s.refreshSessionIfNeeded(nil, req, in.session)
err := s.refreshSessionIfNeeded(nil, req, in.session, false)
if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr))
} else {