diff --git a/CHANGELOG.md b/CHANGELOG.md index 320ba697..2c877b92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -395,6 +395,7 @@ For detailed information, migration guidance, and security implications, see the - [#3116](https://github.com/oauth2-proxy/oauth2-proxy/pull/3116) feat: bump to go1.24.5 and full dependency update (@wardviaene / @dolmen) - [#3097](https://github.com/oauth2-proxy/oauth2-proxy/pull/3097) chore(deps): update alpine base image to v3.22.0 - [#3101](https://github.com/oauth2-proxy/oauth2-proxy/pull/3101) fix: return error for empty Redis URL list (@dgivens) +- [#2019](https://github.com/oauth2-proxy/oauth2-proxy/issues/2019) feat: refresh token on demand (@cmbaatz) # V7.9.0 diff --git a/oauthproxy.go b/oauthproxy.go index e2357c8d..40e986a8 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -53,6 +53,7 @@ const ( oauthCallbackPath = "/callback" authOnlyPath = "/auth" userInfoPath = "/userinfo" + refreshPath = "/refresh" staticPathPrefix = "/static/" idTokenPlaceholder = "{id_token}" @@ -108,6 +109,7 @@ type OAuthProxy struct { trustedIPs *ip.NetSet sessionChain alice.Chain + refreshChain alice.Chain headersChain alice.Chain preAuthChain alice.Chain pageWriter pagewriter.Writer @@ -209,7 +211,22 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } - sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator) + + sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator).Append( + middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: provider.RefreshSession, + ValidateSession: provider.ValidateSession, + })) + refreshChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator).Append( + middleware.NewStoredSessionRefresher(&middleware.StoredSessionLoaderOptions{ + SessionStore: sessionStore, + RefreshPeriod: opts.Cookie.Refresh, + RefreshSession: provider.RefreshSession, + ValidateSession: provider.ValidateSession, + })) + headersChain, err := buildHeadersChain(opts) if err != nil { return nil, fmt.Errorf("could not build headers chain: %v", err) @@ -246,6 +263,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr basicAuthValidator: basicAuthValidator, basicAuthGroups: opts.HtpasswdUserGroups, sessionChain: sessionChain, + refreshChain: refreshChain, headersChain: headersChain, preAuthChain: preAuthChain, pageWriter: pageWriter, @@ -352,6 +370,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // The userinfo and logout endpoints needs to load sessions before handling the request s.Path(userInfoPath).Handler(p.sessionChain.ThenFunc(p.UserInfo)) + s.Path(refreshPath).Handler(p.refreshChain.ThenFunc(p.Proxy)) s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(p.SignOut)) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index e1235a4e..796f426e 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -2066,6 +2066,10 @@ func Test_noCacheHeaders(t *testing.T) { path: "/oauth2/userinfo", hasNoCache: true, }, + { + path: "/oauth2/refresh", + hasNoCache: true, + }, { path: "/upstream", hasNoCache: false, diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 53238f19..53c3fad5 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -92,6 +92,20 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor return ss.loadSession } +// NewStoredSessionRefresher creates a new storedSessionLoader which allows for the +// refresh of sessions from the session store on demand. +// If no session is found, the request will be passed to the nex handler. +// If a session was loader by a previous handler, it will not be replaced. +func NewStoredSessionRefresher(opts *StoredSessionLoaderOptions) alice.Constructor { + sr := &storedSessionLoader{ + store: opts.SessionStore, + refreshPeriod: opts.RefreshPeriod, + sessionRefresher: opts.RefreshSession, + sessionValidator: opts.ValidateSession, + } + return sr.forceRefreshSession +} + // storedSessionLoader is responsible for loading sessions from cookie // identified sessions in the session store. type storedSessionLoader struct { @@ -141,7 +155,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h return nil, err } - err = s.refreshSessionIfNeeded(rw, req, session) + err = s.refreshSessionIfNeeded(rw, req, session, false) if err != nil { return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) } @@ -152,8 +166,8 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h // refreshSessionIfNeeded will attempt to refresh a session if the session // is older than the refresh period. // Success or fail, we will then validate the session. -func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { - if !needsRefresh(s.refreshPeriod, session) { +func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState, forceRefresh bool) error { + if !forceRefresh && !needsRefresh(s.refreshPeriod, session) { // Refresh is disabled or the session is not old enough, do nothing return nil } @@ -207,7 +221,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req // Loading from the session store creates a new lock in the session. session.Lock = lock - if !needsRefresh(s.refreshPeriod, session) { + if !forceRefresh && !needsRefresh(s.refreshPeriod, session) { // The session must have already been refreshed while we were waiting to // obtain the lock. return nil @@ -296,3 +310,58 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess return nil } + +// forceRefreshSession attempts to load a session as identified by the request cookies. +// If no session is found, the request will be passed to the next handler. +// If a session was loaded by a previous handler, it will not be refreshed. +func (s *storedSessionLoader) forceRefreshSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := middlewareapi.GetRequestScope(req) + + // If refreshPeriod isn't defined then forward request to the next handler + if s.refreshPeriod <= time.Duration(0) { + next.ServeHTTP(rw, req) + return + } + + session, err := s.store.Load(req) + if err != nil { + s.handleRefreshError(rw, req, next, err) + return + } + + if session == nil { + next.ServeHTTP(rw, req) + return + } + + err = s.refreshSessionIfNeeded(rw, req, session, true) + if err != nil { + s.handleRefreshError(rw, req, next, err) + return + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +func (s *storedSessionLoader) handleRefreshError(rw http.ResponseWriter, req *http.Request, next http.Handler, err error) { + if errors.Is(err, http.ErrNoCookie) { + logger.Error("Error refreshing session: session cookie not present") + + } else { + // In the case when there was an error loading the session, + // we should clear the session + logger.Errorf("Error loading cookied session: %v, removing session", err) + + if err = s.store.Clear(rw, req); err != nil { + logger.Errorf("Error removing session: %v", err) + } + } + + scope := middlewareapi.GetRequestScope(req) + scope.Session = nil + next.ServeHTTP(rw, req) +} diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index c913a4ef..ecc9c34b 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -85,6 +85,7 @@ var _ = Describe("Stored Session Suite", func() { const ( refresh = "Refresh" refreshed = "Refreshed" + forcedRefresh = "Forced-Refresh" noRefresh = "NoRefresh" notImplemented = "NotImplemented" ) @@ -93,15 +94,23 @@ var _ = Describe("Stored Session Suite", func() { Context("StoredSessionLoader", func() { now := time.Now() + createdPast := now.Add(-5 * time.Minute) + recentPast := now.Add(-5 * time.Second) createdFuture := now.Add(5 * time.Minute) + recentFuture := now.Add(1 * time.Minute) + clock := func() time.Time { return now } + recentFutureClock := func() time.Time { return recentFuture } var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { switch ss.RefreshToken { case refresh: ss.RefreshToken = refreshed return true, nil + case forcedRefresh: + ss.RefreshToken = refreshed + return true, nil case noRefresh: return false, nil default: @@ -143,7 +152,7 @@ var _ = Describe("Stored Session Suite", func() { RefreshToken: refresh, CreatedAt: &createdPast, ExpiresOn: &createdFuture, - Clock: clock, + Clock: recentFutureClock, }, nil case "_oauth2_proxy=RefreshError": return &sessionsapi.SessionState{ @@ -152,6 +161,20 @@ var _ = Describe("Stored Session Suite", func() { ExpiresOn: &createdFuture, Clock: clock, }, nil + case "_oauth2_proxy=NewSession": + return &sessionsapi.SessionState{ + RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten + CreatedAt: &recentPast, + ExpiresOn: &createdFuture, + Clock: recentFutureClock, + }, nil + case "_oauth2_proxy=OldSession": + return &sessionsapi.SessionState{ + RefreshToken: "Forced-Refresh", // used to inform test to allow refresh. Will be overwritten + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + Clock: recentFutureClock, + }, nil case "_oauth2_proxy=NonExistent": return nil, fmt.Errorf("invalid cookie") default: @@ -170,7 +193,7 @@ var _ = Describe("Stored Session Suite", func() { validateSession func(context.Context, *sessionsapi.SessionState) bool } - DescribeTable("when serving a request", + DescribeTable("when serving a loadSession request", func(in storedSessionLoaderTableInput) { scope := &middlewareapi.RequestScope{ Session: in.existingSession, @@ -198,11 +221,12 @@ var _ = Describe("Stored Session Suite", func() { })) handler.ServeHTTP(rw, req) - // Compare, ignoring testing Clock. if in.expectedSession == nil { Expect(gotSession).To(BeNil()) return } + + // Compare, ignoring testing Clock. Expect(gotSession).ToNot(BeNil()) got := *gotSession got.Clock = nil @@ -292,7 +316,7 @@ var _ = Describe("Stored Session Suite", func() { existingSession: nil, expectedSession: &sessionsapi.SessionState{ RefreshToken: "Refreshed", - CreatedAt: &now, + CreatedAt: &recentFuture, ExpiresOn: &createdFuture, Lock: &sessionsapi.NoOpLock{}, }, @@ -341,6 +365,192 @@ var _ = Describe("Stored Session Suite", func() { }), ) + DescribeTable("when service a forceRefresh request", + func(in storedSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the request header and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header = in.requestHeaders + req = middlewareapi.AddRequestScope(req, scope) + + rw := httptest.NewRecorder() + + opts := &StoredSessionLoaderOptions{ + SessionStore: in.store, + RefreshPeriod: in.refreshPeriod, + RefreshSession: in.refreshSession, + ValidateSession: in.validateSession, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewStoredSessionRefresher(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = middlewareapi.GetRequestScope(r).Session + })) + handler.ServeHTTP(rw, req) + + if in.expectedSession == nil { + Expect(gotSession).To(BeNil()) + return + } + + // Compare, ignoring testing Clock. + Expect(gotSession).ToNot(BeNil()) + got := *gotSession + got.Clock = nil + Expect(&got).To(Equal(in.expectedSession)) + }, + Entry("with no cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{}, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an invalid cookie", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NonExistent"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with an existing session", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshSession"}, + }, + existingSession: &sessionsapi.SessionState{ + RefreshToken: "Forced-Refresh", + CreatedAt: &recentPast, + }, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &recentFuture, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that has not expired and cannot be refreshed", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: noRefresh, + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NewSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &recentFuture, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=OldSession"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "Refreshed", + CreatedAt: &recentFuture, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: &sessionsapi.SessionState{ + RefreshToken: "RefreshError", + CreatedAt: &createdPast, + ExpiresOn: &createdFuture, + Lock: &sessionsapi.NoOpLock{}, + }, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=RefreshError"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false }, + }), + Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 1 * time.Minute, + refreshSession: defaultRefreshFunc, + validateSession: defaultValidateFunc, + }), + Entry("when refresh period is not defined", storedSessionLoaderTableInput{ + requestHeaders: http.Header{ + "Cookie": []string{"_oauth2_proxy=NewSession"}, + }, + existingSession: nil, + expectedSession: nil, + store: defaultSessionStore, + refreshPeriod: 0, + }), + ) + type storedSessionLoaderConcurrentTableInput struct { existingSession *sessionsapi.SessionState refreshPeriod time.Duration @@ -502,7 +712,7 @@ var _ = Describe("Stored Session Suite", func() { } req := httptest.NewRequest("", "/", nil) - err := s.refreshSessionIfNeeded(nil, req, in.session) + err := s.refreshSessionIfNeeded(nil, req, in.session, false) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else {