This commit is contained in:
Chadwick Baatz 2025-11-28 14:17:42 -03:00 committed by GitHub
commit a6b990ff08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 313 additions and 10 deletions

View File

@ -142,6 +142,7 @@ For detailed information, migration guidance, and security implications, see the
- [#3116](https://github.com/oauth2-proxy/oauth2-proxy/pull/3116) feat: bump to go1.24.5 and full dependency update (@wardviaene / @dolmen)
- [#3097](https://github.com/oauth2-proxy/oauth2-proxy/pull/3097) chore(deps): update alpine base image to v3.22.0
- [#3101](https://github.com/oauth2-proxy/oauth2-proxy/pull/3101) fix: return error for empty Redis URL list (@dgivens)
- [#2019](https://github.com/oauth2-proxy/oauth2-proxy/issues/2019) feat: refresh token on demand (@cmbaatz)
# V7.9.0

View File

@ -53,6 +53,7 @@ const (
oauthCallbackPath = "/callback"
authOnlyPath = "/auth"
userInfoPath = "/userinfo"
refreshPath = "/refresh"
staticPathPrefix = "/static/"
)
@ -104,6 +105,7 @@ type OAuthProxy struct {
trustedIPs *ip.NetSet
sessionChain alice.Chain
refreshChain alice.Chain
headersChain alice.Chain
preAuthChain alice.Chain
pageWriter pagewriter.Writer
@ -204,7 +206,22 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
}
sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator)
sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator).Append(
middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
}))
refreshChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator).Append(
middleware.NewStoredSessionRefresher(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
}))
headersChain, err := buildHeadersChain(opts)
if err != nil {
return nil, fmt.Errorf("could not build headers chain: %v", err)
@ -241,6 +258,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
basicAuthValidator: basicAuthValidator,
basicAuthGroups: opts.HtpasswdUserGroups,
sessionChain: sessionChain,
refreshChain: refreshChain,
headersChain: headersChain,
preAuthChain: preAuthChain,
pageWriter: pageWriter,
@ -347,6 +365,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {
// The userinfo and logout endpoints needs to load sessions before handling the request
s.Path(userInfoPath).Handler(p.sessionChain.ThenFunc(p.UserInfo))
s.Path(refreshPath).Handler(p.refreshChain.ThenFunc(p.Proxy))
s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(p.SignOut))
}

View File

@ -2007,6 +2007,10 @@ func Test_noCacheHeaders(t *testing.T) {
path: "/oauth2/userinfo",
hasNoCache: true,
},
{
path: "/oauth2/refresh",
hasNoCache: true,
},
{
path: "/upstream",
hasNoCache: false,

View File

@ -64,6 +64,20 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor
return ss.loadSession
}
// NewStoredSessionRefresher creates a new storedSessionLoader which allows for the
// refresh of sessions from the session store on demand.
// If no session is found, the request will be passed to the nex handler.
// If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionRefresher(opts *StoredSessionLoaderOptions) alice.Constructor {
sr := &storedSessionLoader{
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
}
return sr.forceRefreshSession
}
// storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store.
type storedSessionLoader struct {
@ -113,7 +127,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
return nil, err
}
err = s.refreshSessionIfNeeded(rw, req, session)
err = s.refreshSessionIfNeeded(rw, req, session, false)
if err != nil {
return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
}
@ -124,8 +138,8 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
// refreshSessionIfNeeded will attempt to refresh a session if the session
// is older than the refresh period.
// Success or fail, we will then validate the session.
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
if !needsRefresh(s.refreshPeriod, session) {
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState, forceRefresh bool) error {
if !forceRefresh && !needsRefresh(s.refreshPeriod, session) {
// Refresh is disabled or the session is not old enough, do nothing
return nil
}
@ -179,7 +193,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
// Loading from the session store creates a new lock in the session.
session.Lock = lock
if !needsRefresh(s.refreshPeriod, session) {
if !forceRefresh && !needsRefresh(s.refreshPeriod, session) {
// The session must have already been refreshed while we were waiting to
// obtain the lock.
return nil
@ -252,3 +266,58 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess
return nil
}
// forceRefreshSession attempts to load a session as identified by the request cookies.
// If no session is found, the request will be passed to the next handler.
// If a session was loaded by a previous handler, it will not be refreshed.
func (s *storedSessionLoader) forceRefreshSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := middlewareapi.GetRequestScope(req)
// If refreshPeriod isn't defined then forward request to the next handler
if s.refreshPeriod <= time.Duration(0) {
next.ServeHTTP(rw, req)
return
}
session, err := s.store.Load(req)
if err != nil {
s.handleRefreshError(rw, req, next, err)
return
}
if session == nil {
next.ServeHTTP(rw, req)
return
}
err = s.refreshSessionIfNeeded(rw, req, session, true)
if err != nil {
s.handleRefreshError(rw, req, next, err)
return
}
// Add the session to the scope if it was found
scope.Session = session
next.ServeHTTP(rw, req)
})
}
func (s *storedSessionLoader) handleRefreshError(rw http.ResponseWriter, req *http.Request, next http.Handler, err error) {
if errors.Is(err, http.ErrNoCookie) {
logger.Error("Error refreshing session: session cookie not present")
} else {
// In the case when there was an error loading the session,
// we should clear the session
logger.Errorf("Error loading cookied session: %v, removing session", err)
if err = s.store.Clear(rw, req); err != nil {
logger.Errorf("Error removing session: %v", err)
}
}
scope := middlewareapi.GetRequestScope(req)
scope.Session = nil
next.ServeHTTP(rw, req)
}

View File

@ -84,6 +84,7 @@ var _ = Describe("Stored Session Suite", func() {
const (
refresh = "Refresh"
refreshed = "Refreshed"
forcedRefresh = "Forced-Refresh"
noRefresh = "NoRefresh"
notImplemented = "NotImplemented"
)
@ -92,15 +93,23 @@ var _ = Describe("Stored Session Suite", func() {
Context("StoredSessionLoader", func() {
now := time.Now()
createdPast := now.Add(-5 * time.Minute)
recentPast := now.Add(-5 * time.Second)
createdFuture := now.Add(5 * time.Minute)
recentFuture := now.Add(1 * time.Minute)
clock := func() time.Time { return now }
recentFutureClock := func() time.Time { return recentFuture }
var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) {
switch ss.RefreshToken {
case refresh:
ss.RefreshToken = refreshed
return true, nil
case forcedRefresh:
ss.RefreshToken = refreshed
return true, nil
case noRefresh:
return false, nil
default:
@ -142,7 +151,7 @@ var _ = Describe("Stored Session Suite", func() {
RefreshToken: refresh,
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Clock: clock,
Clock: recentFutureClock,
}, nil
case "_oauth2_proxy=RefreshError":
return &sessionsapi.SessionState{
@ -151,6 +160,20 @@ var _ = Describe("Stored Session Suite", func() {
ExpiresOn: &createdFuture,
Clock: clock,
}, nil
case "_oauth2_proxy=NewSession":
return &sessionsapi.SessionState{
RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten
CreatedAt: &recentPast,
ExpiresOn: &createdFuture,
Clock: recentFutureClock,
}, nil
case "_oauth2_proxy=OldSession":
return &sessionsapi.SessionState{
RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Clock: recentFutureClock,
}, nil
case "_oauth2_proxy=NonExistent":
return nil, fmt.Errorf("invalid cookie")
default:
@ -169,7 +192,7 @@ var _ = Describe("Stored Session Suite", func() {
validateSession func(context.Context, *sessionsapi.SessionState) bool
}
DescribeTable("when serving a request",
DescribeTable("when serving a loadSession request",
func(in storedSessionLoaderTableInput) {
scope := &middlewareapi.RequestScope{
Session: in.existingSession,
@ -197,11 +220,12 @@ var _ = Describe("Stored Session Suite", func() {
}))
handler.ServeHTTP(rw, req)
// Compare, ignoring testing Clock.
if in.expectedSession == nil {
Expect(gotSession).To(BeNil())
return
}
// Compare, ignoring testing Clock.
Expect(gotSession).ToNot(BeNil())
got := *gotSession
got.Clock = nil
@ -291,7 +315,7 @@ var _ = Describe("Stored Session Suite", func() {
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &now,
CreatedAt: &recentFuture,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
@ -340,6 +364,192 @@ var _ = Describe("Stored Session Suite", func() {
}),
)
DescribeTable("when service a forceRefresh request",
func(in storedSessionLoaderTableInput) {
scope := &middlewareapi.RequestScope{
Session: in.existingSession,
}
// Set up the request with the request header and a request scope
req := httptest.NewRequest("", "/", nil)
req.Header = in.requestHeaders
req = middlewareapi.AddRequestScope(req, scope)
rw := httptest.NewRecorder()
opts := &StoredSessionLoaderOptions{
SessionStore: in.store,
RefreshPeriod: in.refreshPeriod,
RefreshSession: in.refreshSession,
ValidateSession: in.validateSession,
}
// Create the handler with a next handler that will capture the session
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewStoredSessionRefresher(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)
if in.expectedSession == nil {
Expect(gotSession).To(BeNil())
return
}
// Compare, ignoring testing Clock.
Expect(gotSession).ToNot(BeNil())
got := *gotSession
got.Clock = nil
Expect(&got).To(Equal(in.expectedSession))
},
Entry("with no cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with an invalid cookie", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NonExistent"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with an existing session", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshSession"},
},
existingSession: &sessionsapi.SessionState{
RefreshToken: "Forced-Refresh",
CreatedAt: &recentPast,
},
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &recentFuture,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that has not expired and cannot be refreshed", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NoRefreshSession"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: noRefresh,
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NewSession"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &recentFuture,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=OldSession"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "Refreshed",
CreatedAt: &recentFuture,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "RefreshError",
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false },
}),
Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when refresh period is not defined", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=NewSession"},
},
existingSession: nil,
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 0,
}),
)
type storedSessionLoaderConcurrentTableInput struct {
existingSession *sessionsapi.SessionState
refreshPeriod time.Duration
@ -501,7 +711,7 @@ var _ = Describe("Stored Session Suite", func() {
}
req := httptest.NewRequest("", "/", nil)
err := s.refreshSessionIfNeeded(nil, req, in.session)
err := s.refreshSessionIfNeeded(nil, req, in.session, false)
if in.expectedErr != nil {
Expect(err).To(MatchError(in.expectedErr))
} else {