diff --git a/CHANGELOG.md b/CHANGELOG.md index 72d615e0..d5669286 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - [#3116](https://github.com/oauth2-proxy/oauth2-proxy/pull/3116) feat: bump to go1.24.5 and full dependency update (@wardviaene / @dolmen) - [#3097](https://github.com/oauth2-proxy/oauth2-proxy/pull/3097) chore(deps): update alpine base image to v3.22.0 - [#3101](https://github.com/oauth2-proxy/oauth2-proxy/pull/3101) fix: return error for empty Redis URL list (@dgivens) +- [#2019](https://github.com/oauth2-proxy/oauth2-proxy/issues/2019) feat: refresh token on demand (@cmbaatz) # V7.9.0 diff --git a/oauthproxy.go b/oauthproxy.go index b85c89b4..5535ff43 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -53,6 +53,7 @@ const ( oauthCallbackPath = "/callback" authOnlyPath = "/auth" userInfoPath = "/userinfo" + refreshPath = "/refresh" staticPathPrefix = "/static/" ) @@ -104,6 +105,7 @@ type OAuthProxy struct { trustedIPs *ip.NetSet sessionChain alice.Chain + refreshChain alice.Chain headersChain alice.Chain preAuthChain alice.Chain pageWriter pagewriter.Writer @@ -204,7 +206,21 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } - sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator) + sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator).Append( + middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: provider.RefreshSession, + ValidateSession: provider.ValidateSession, + })) + refreshChain := buildSessionChain(opts, provider, basicAuthValidator).Append( + middleware.NewStoredSessionRefresher(&middleware.StoredSessionLoaderOptions{ + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: provider.RefreshSession, + ValidateSession: provider.ValidateSession, + })) + headersChain, err := buildHeadersChain(opts) if err != nil { return nil, fmt.Errorf("could not build headers chain: %v", err) @@ -241,6 +257,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr basicAuthValidator: basicAuthValidator, basicAuthGroups: opts.HtpasswdUserGroups, sessionChain: sessionChain, + refreshChain: refreshChain, headersChain: headersChain, preAuthChain: preAuthChain, pageWriter: pageWriter, @@ -347,6 +364,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // The userinfo and logout endpoints needs to load sessions before handling the request s.Path(userInfoPath).Handler(p.sessionChain.ThenFunc(p.UserInfo)) + s.Path(refreshPath).Handler(p.refreshChain.ThenFunc(p.Proxy)) s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(p.SignOut)) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 488b8cea..78206fff 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -2006,6 +2006,10 @@ func Test_noCacheHeaders(t *testing.T) { path: "/oauth2/userinfo", hasNoCache: true, }, + { + path: "/oauth2/refresh", + hasNoCache: true, + }, { path: "/upstream", hasNoCache: false, diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index f861c756..a37d4987 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -64,6 +64,20 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor return ss.loadSession } +// NewStoredSessionRefresher creates a new storedSessionLoader which allows for the +// refresh of sessions from the session store on demand. +// If no session is found, the request will be passed to the nex handler. +// If a session was loader by a previous handler, it will not be replaced. +func NewStoredSessionRefresher(opts *StoredSessionLoaderOptions) alice.Constructor { + sr := &storedSessionLoader{ + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + sessionRefresher: opts.RefreshSession, + sessionValidator: opts.ValidateSession, + } + return sr.forceRefreshSession +} + // storedSessionLoader is responsible for loading sessions from cookie // identified sessions in the session store. type storedSessionLoader struct { @@ -113,7 +127,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h return nil, err } - err = s.refreshSessionIfNeeded(rw, req, session) + err = s.refreshSessionIfNeeded(rw, req, session, false) if err != nil { return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) } @@ -124,8 +138,8 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // refreshSessionIfNeeded will attempt to refresh a session if the session // is older than the refresh period. // Success or fail, we will then validate the session. -func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - if !needsRefresh(s.refreshPeriod, session) { +func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState, forceRefresh bool) error { + if !forceRefresh && !needsRefresh(s.refreshPeriod, session) { // Refresh is disabled or the session is not old enough, do nothing return nil } @@ -179,7 +193,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req // Loading from the session store creates a new lock in the session. session.Lock = lock - if !needsRefresh(s.refreshPeriod, session) { + if !forceRefresh && !needsRefresh(s.refreshPeriod, session) { // The session must have already been refreshed while we were waiting to // obtain the lock. return nil @@ -252,3 +266,58 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess return nil } + +// forceRefreshSession attempts to load a session as identified by the request cookies. +// If no session is found, the request will be passed to the next handler. +// If a session was loaded by a previous handler, it will not be refreshed. +func (s *storedSessionLoader) forceRefreshSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := middlewareapi.GetRequestScope(req) + + // If refreshPeriod isn't defined then forward request to the next handler + if s.refreshPeriod <= time.Duration(0) { + next.ServeHTTP(rw, req) + return + } + + session, err := s.store.Load(req) + if err != nil { + s.handleRefreshError(rw, req, next, err) + return + } + + if session == nil { + next.ServeHTTP(rw, req) + return + } + + err = s.refreshSessionIfNeeded(rw, req, session, true) + if err != nil { + s.handleRefreshError(rw, req, next, err) + return + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +func (s *storedSessionLoader) handleRefreshError(rw http.ResponseWriter, req *http.Request, next http.Handler, err error) { + if errors.Is(err, http.ErrNoCookie) { + logger.Error("Error refreshing session: session cookie not present") + + } else { + // In the case when there was an error loading the session, + // we should clear the session + logger.Errorf("Error loading cookied session: %v, removing session", err) + + if err = s.store.Clear(rw, req); err != nil { + logger.Errorf("Error removing session: %v", err) + } + } + + scope := middlewareapi.GetRequestScope(req) + scope.Session = nil + next.ServeHTTP(rw, req) +} diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 904c2028..58b3baa2 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -85,6 +85,7 @@ var _ = Describe("Stored Session Suite", func() { const ( refresh = "Refresh" refreshed = "Refreshed" + forcedRefresh = "Forced-Refresh" noRefresh = "NoRefresh" notImplemented = "NotImplemented" ) @@ -94,13 +95,18 @@ var _ = Describe("Stored Session Suite", func() { Context("StoredSessionLoader", func() { now := time.Now() createdPast := now.Add(-5 * time.Minute) + recentPast := now.Add(-5 * time.Second) createdFuture := now.Add(5 * time.Minute) + recent := now.Add(1 * time.Minute) var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: ss.RefreshToken = refreshed return true, nil + case forcedRefresh: + ss.RefreshToken = refreshed + return true, nil case noRefresh: return false, nil default: @@ -146,6 +152,18 @@ var _ = Describe("Stored Session Suite", func() { CreatedAt: &createdPast, ExpiresOn: &createdFuture, }, nil + case "_oauth2_proxy=NewSession": + return &sessionsapi.SessionState{ + RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten + CreatedAt: &recentPast, + ExpiresOn: &createdFuture, + }, nil + case "_oauth2_proxy=OldSession": + return &sessionsapi.SessionState{ + RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + }, nil case "_oauth2_proxy=NonExistent": return nil, fmt.Errorf("invalid cookie") default: @@ -155,7 +173,7 @@ var _ = Describe("Stored Session Suite", func() { } BeforeEach(func() { - clock.Set(now) + clock.Set(recent) }) AfterEach(func() { @@ -172,7 +190,7 @@ var _ = Describe("Stored Session Suite", func() { validateSession func(context.Context, *sessionsapi.SessionState) bool } - DescribeTable("when serving a request", + DescribeTable("when serving a loadSession request", func(in storedSessionLoaderTableInput) { scope := &middlewareapi.RequestScope{ Session: in.existingSession, @@ -286,7 +304,7 @@ var _ = Describe("Stored Session Suite", func() { existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: "Refreshed", - CreatedAt: &now, + CreatedAt: &recent, ExpiresOn: &createdFuture, Lock: &sessionsapi.NoOpLock{}, }, @@ -335,6 +353,183 @@ var _ = Describe("Stored Session Suite", func() { }), ) + DescribeTable("when service a forceRefresh request", + func(in storedSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the request header and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header = in.requestHeaders + req = middlewareapi.AddRequestScope(req, scope) + + rw := httptest.NewRecorder() + + opts := &StoredSessionLoaderOptions{ + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: in.refreshSession, + ValidateSession: in.validateSession, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewStoredSessionRefresher(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = middlewareapi.GetRequestScope(r).Session + })) + handler.ServeHTTP(rw, req) + + Expect(gotSession).To(Equal(in.expectedSession)) + }, + Entry("with no cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{}, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an invalid cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NonExistent"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an existing session", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NewSession"}, + }, + existingSession: &sessionsapi.SessionState{ + RefreshToken: "Forced-Refresh", + CreatedAt: &recentPast, + }, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &recent, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that has not expired and cannot be refreshed", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NewSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &recent, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=OldSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &recent, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false }, + }), + Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when refresh period is not defined", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NewSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 0, + }), + ) + type storedSessionLoaderConcurrentTableInput struct { existingSession *sessionsapi.SessionState refreshPeriod time.Duration @@ -496,7 +691,7 @@ var _ = Describe("Stored Session Suite", func() { } req := httptest.NewRequest("", "/", nil) - err := s.refreshSessionIfNeeded(nil, req, in.session) + err := s.refreshSessionIfNeeded(nil, req, in.session, false) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else {