diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index f861c756..c4b2b718 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "strings" "time" "github.com/justinas/alice" @@ -31,6 +32,34 @@ const ( sessionRefreshRetryPeriod = 10 * time.Millisecond ) +// isFatalRefreshError checks if a refresh error indicates a revoked or +// non-existent session that should be immediately invalidated. +// Fatal errors indicate the session is no longer valid at the provider level. +// Non-fatal errors (network issues, timeouts) should not invalidate the session. +// +// Only checks standard OAuth2 error codes (RFC 6749 Section 5.2). +// Does NOT check error_description strings as they are optional and provider-specific. +func isFatalRefreshError(err error) bool { + if err == nil { + return false + } + + errStr := err.Error() + // Only check standard OAuth2 error codes (RFC 6749 Section 5.2) + // Do NOT check error_description strings as they are optional and provider-specific + fatalErrors := []string{ + "invalid_grant", // refresh token revoked, expired, or session terminated + "invalid_client", // client credentials no longer valid + } + + for _, fe := range fatalErrors { + if strings.Contains(errStr, fe) { + return true + } + } + return false +} + // StoredSessionLoaderOptions contains all of the requirements to construct // a stored session loader. // All options must be provided. @@ -188,9 +217,25 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req // We are holding the lock and the session needs a refresh logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) if err := s.refreshSession(rw, req, session); err != nil { - // If a preemptive refresh fails, we still keep the session - // if validateSession succeeds. logger.Errorf("Unable to refresh session: %v", err) + + // Check if this is a fatal error that indicates the session is revoked + // or no longer valid at the provider level + if isFatalRefreshError(err) { + logger.Printf("Fatal refresh error detected (session revoked or invalid), clearing session for user: %s", session.User) + + // Clear the session from storage (Redis) and remove the cookie + clearErr := s.store.Clear(rw, req) + if clearErr != nil { + logger.Errorf("Error clearing session: %v", clearErr) + } + + // Return error immediately to force re-authentication + return fmt.Errorf("session invalidated due to fatal refresh error: %w", err) + } + + // For non-fatal errors (network issues, timeouts), keep the session + // and let validateSession determine if it's still usable } // Validate all sessions after any Redeem/Refresh operation (fail or success) diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index d8e78f2f..c913a4ef 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "sync" + "testing" "time" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" @@ -801,3 +802,53 @@ func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro func (f *fakeSessionStore) VerifyConnection(_ context.Context) error { return nil } + +// TestIsFatalRefreshError tests the isFatalRefreshError function to ensure +// it correctly identifies fatal OAuth2 errors that should invalidate a session. +func TestIsFatalRefreshError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "invalid_grant error", + err: fmt.Errorf("failed to get token: oauth2: \"invalid_grant\" \"Session not active\""), + expected: true, + }, + { + name: "invalid_client error", + err: fmt.Errorf("invalid_client: client not found"), + expected: true, + }, + { + name: "network timeout - not fatal", + err: fmt.Errorf("Post \"https://keycloak/token\": dial tcp: connect: connection refused"), + expected: false, + }, + { + name: "server error - not fatal", + err: fmt.Errorf("unexpected status code 500"), + expected: false, + }, + { + name: "generic refresh error - not fatal", + err: fmt.Errorf("error refreshing tokens: context deadline exceeded"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isFatalRefreshError(tt.err) + if result != tt.expected { + t.Errorf("isFatalRefreshError(%v) = %v, want %v", tt.err, result, tt.expected) + } + }) + } +}