From cfb507e4ac52814e3eaa0e98a2b60609aa8f7d37 Mon Sep 17 00:00:00 2001 From: Fedor Korotkov Date: Wed, 6 May 2026 13:34:39 -0400 Subject: [PATCH] Simplify exec SSH reuse with lazy cache --- internal/controller/api_vms_exec.go | 15 +- internal/controller/controller.go | 6 +- internal/controller/exec_sessions.go | 19 +-- internal/controller/exec_sessions_test.go | 8 +- internal/controller/exec_ssh_pool.go | 163 +++++----------------- internal/controller/exec_ssh_pool_test.go | 143 ++++++++++--------- 6 files changed, 123 insertions(+), 231 deletions(-) diff --git a/internal/controller/api_vms_exec.go b/internal/controller/api_vms_exec.go index f4cd663..dc0288e 100644 --- a/internal/controller/api_vms_exec.go +++ b/internal/controller/api_vms_exec.go @@ -189,8 +189,7 @@ func (controller *Controller) newSSHExecSession( sessionContext, sessionContextCancel := context.WithCancel(context.Background()) type sshExecAttempt struct { - lease *execSSHTransportLease - exec sshExecRunner + exec sshExecRunner } transportKey := execSSHTransportKey{ @@ -206,7 +205,7 @@ func (controller *Controller) newSSHExecSession( retry.Attempts(0), retry.LastErrorOnly(true), ).Do(func() (sshExecAttempt, error) { - lease, err := controller.execSSHPool.acquire(waitContext, transportKey, func() (execSSHTransport, error) { + transport, reused, err := controller.execSSHCache.getOrCreate(transportKey, func() (execSSHTransport, error) { portForwardConn, err := controller.portForwardConnection( context.Background(), waitContext, @@ -229,17 +228,17 @@ func (controller *Controller) newSSHExecSession( return sshExecAttempt{}, err } - exec, err := lease.transport().NewExec(sshexec.Options{ + exec, err := transport.NewExec(sshexec.Options{ Interactive: spec.interactive, TTY: spec.tty, Rows: spec.rows, Cols: spec.cols, }) if err != nil { - lease.release() + controller.execSSHCache.discard(transportKey, transport) err = fmt.Errorf("failed to create SSH session for a VM: %w", err) - if lease.reused { + if reused { return sshExecAttempt{}, retry.Unrecoverable(err) } @@ -247,8 +246,7 @@ func (controller *Controller) newSSHExecSession( } return sshExecAttempt{ - lease: lease, - exec: exec, + exec: exec, }, nil }) if err != nil { @@ -264,7 +262,6 @@ func (controller *Controller) newSSHExecSession( spec, runCommand, attempt.exec, - attempt.lease.release, registry, controller.execSessionExitTTL, policy, diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 0cdf04e..c0efcd9 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -66,7 +66,7 @@ type Controller struct { sshNoClientAuth bool sshServer *sshserver.SSHServer execSessions *execSessionRegistry - execSSHPool *execSSHTransportPool + execSSHCache *execSSHTransportCache single singleflight.Group @@ -81,7 +81,7 @@ func New(opts ...Option) (*Controller, error) { execSessionExitTTL: 10 * time.Minute, pingInterval: 30 * time.Second, execSessions: newExecSessionRegistry(), - execSSHPool: newExecSSHTransportPool(), + execSSHCache: newExecSSHTransportCache(), single: singleflight.Group{}, } @@ -315,7 +315,7 @@ func (controller *Controller) Run(ctx context.Context) error { <-ctx.Done() controller.execSessions.closeAll() - controller.execSSHPool.closeAll() + controller.execSSHCache.closeAll() if err := controller.httpServer.Shutdown(ctx); err != nil { controller.logger.Errorf("failed to cleanly shutdown the HTTP server: %v", err) diff --git a/internal/controller/exec_sessions.go b/internal/controller/exec_sessions.go index f55aff8..5c02106 100644 --- a/internal/controller/exec_sessions.go +++ b/internal/controller/exec_sessions.go @@ -328,7 +328,6 @@ type execSession struct { spec execSessionSpec command string exec sshExecRunner - release func() registry *execSessionRegistry exitTTL time.Duration policy execSessionPolicy @@ -347,7 +346,6 @@ type execSession struct { expiryTimer *time.Timer startOnce sync.Once - closeOnce sync.Once done chan struct{} doneOnce sync.Once } @@ -356,7 +354,6 @@ func newExecSession( key execSessionKey, command string, exec sshExecRunner, - release func(), registry *execSessionRegistry, exitTTL time.Duration, policy execSessionPolicy, @@ -370,7 +367,6 @@ func newExecSession( execSessionSpec{command: command}, command, exec, - release, registry, exitTTL, policy, @@ -384,7 +380,6 @@ func newExecSessionWithContextAndSpec( spec execSessionSpec, command string, exec sshExecRunner, - release func(), registry *execSessionRegistry, exitTTL time.Duration, policy execSessionPolicy, @@ -398,7 +393,6 @@ func newExecSessionWithContextAndSpec( spec: spec.clone(), command: command, exec: exec, - release: release, registry: registry, exitTTL: exitTTL, policy: policy, @@ -574,7 +568,7 @@ func (session *execSession) close() { closeSubscribers(subscribers) session.cancel() - session.closeCommandResources() + _ = session.exec.Close() if session.registry != nil { session.registry.remove(session.key, session) } @@ -658,7 +652,7 @@ func (session *execSession) markFinished() { close(session.done) }) - session.closeCommandResources() + _ = session.exec.Close() if shouldClose { session.close() @@ -692,15 +686,6 @@ func (session *execSession) dropSubscriber(subscriber *execSessionSubscriber) { session.detachLocked(subscriber) } -func (session *execSession) closeCommandResources() { - session.closeOnce.Do(func() { - _ = session.exec.Close() - if session.release != nil { - session.release() - } - }) -} - func cloneExecFrame(frame *execstream.Frame) *execstream.Frame { if frame == nil { return nil diff --git a/internal/controller/exec_sessions_test.go b/internal/controller/exec_sessions_test.go index 78aded6..d51474a 100644 --- a/internal/controller/exec_sessions_test.go +++ b/internal/controller/exec_sessions_test.go @@ -126,7 +126,6 @@ func TestExecSessionStartRunsCommandOnlyOnce(t *testing.T) { }, }, nil, - nil, time.Minute, reconnectableExecSessionPolicy, ) @@ -387,13 +386,9 @@ func TestExecSessionFinishKeepsReconnectableSubscriberOpen(t *testing.T) { require.EqualValues(t, 2, noMoreHistory.Watermark) } -func TestExecSessionFinishReleasesTransportWhileRetainingHistory(t *testing.T) { +func TestExecSessionFinishClosesCommandWhileRetainingHistory(t *testing.T) { registry := newExecSessionRegistry() session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) - var releaseCalls atomic.Int32 - session.release = func() { - releaseCalls.Add(1) - } session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")}) session.recordFrame(&execstream.Frame{ @@ -403,7 +398,6 @@ func TestExecSessionFinishReleasesTransportWhileRetainingHistory(t *testing.T) { session.markFinished() require.False(t, session.closed) - require.EqualValues(t, 1, releaseCalls.Load()) require.EqualValues(t, 1, session.exec.(*fakeExec).closeCalls.Load()) subscriber, err := session.attach() diff --git a/internal/controller/exec_ssh_pool.go b/internal/controller/exec_ssh_pool.go index 777e5bb..a0e4457 100644 --- a/internal/controller/exec_ssh_pool.go +++ b/internal/controller/exec_ssh_pool.go @@ -1,7 +1,6 @@ package controller import ( - "context" "sync" "github.com/cirruslabs/orchard/internal/controller/sshexec" @@ -30,153 +29,63 @@ type execSSHTransportKey struct { restartCount uint64 } -type execSSHTransportEntry struct { - key execSSHTransportKey - transport execSSHTransport - refs int - closed bool +type execSSHTransportCache struct { + mu sync.Mutex + entries map[execSSHTransportKey]execSSHTransport } -type execSSHTransportCreation struct { - done chan struct{} - err error -} - -type execSSHTransportPool struct { - mu sync.Mutex - entries map[execSSHTransportKey]*execSSHTransportEntry - creating map[execSSHTransportKey]*execSSHTransportCreation -} - -func newExecSSHTransportPool() *execSSHTransportPool { - return &execSSHTransportPool{ - entries: map[execSSHTransportKey]*execSSHTransportEntry{}, - creating: map[execSSHTransportKey]*execSSHTransportCreation{}, +func newExecSSHTransportCache() *execSSHTransportCache { + return &execSSHTransportCache{ + entries: map[execSSHTransportKey]execSSHTransport{}, } } -func (pool *execSSHTransportPool) acquire( - ctx context.Context, +func (cache *execSSHTransportCache) getOrCreate( key execSSHTransportKey, create func() (execSSHTransport, error), -) (*execSSHTransportLease, error) { - for { - pool.mu.Lock() +) (execSSHTransport, bool, error) { + cache.mu.Lock() + defer cache.mu.Unlock() - if entry, ok := pool.entries[key]; ok { - entry.refs++ - pool.mu.Unlock() - - return &execSSHTransportLease{ - pool: pool, - entry: entry, - reused: true, - }, nil - } - - if creation, ok := pool.creating[key]; ok { - pool.mu.Unlock() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-creation.done: - if creation.err != nil { - return nil, creation.err - } - } - - continue - } - - creation := &execSSHTransportCreation{done: make(chan struct{})} - pool.creating[key] = creation - pool.mu.Unlock() - - transport, err := create() - - pool.mu.Lock() - delete(pool.creating, key) - creation.err = err - - var entry *execSSHTransportEntry - if err == nil { - entry = &execSSHTransportEntry{ - key: key, - transport: transport, - refs: 1, - } - pool.entries[key] = entry - } - - close(creation.done) - pool.mu.Unlock() - - if err != nil { - return nil, err - } - - return &execSSHTransportLease{ - pool: pool, - entry: entry, - }, nil + if transport, ok := cache.entries[key]; ok { + return transport, true, nil } + + transport, err := create() + if err != nil { + return nil, false, err + } + + cache.entries[key] = transport + + return transport, false, nil } -func (pool *execSSHTransportPool) release(entry *execSSHTransportEntry) { +func (cache *execSSHTransportCache) discard(key execSSHTransportKey, expected execSSHTransport) { var transport execSSHTransport - pool.mu.Lock() - if entry.closed { - pool.mu.Unlock() - - return + cache.mu.Lock() + if cache.entries[key] == expected { + transport = expected + delete(cache.entries, key) } - - entry.refs-- - if entry.refs == 0 { - entry.closed = true - if pool.entries[entry.key] == entry { - delete(pool.entries, entry.key) - } - transport = entry.transport - } - pool.mu.Unlock() + cache.mu.Unlock() if transport != nil { _ = transport.Close() } } -func (pool *execSSHTransportPool) closeAll() { - pool.mu.Lock() - entries := make([]*execSSHTransportEntry, 0, len(pool.entries)) - for key, entry := range pool.entries { - entry.closed = true - delete(pool.entries, key) - entries = append(entries, entry) +func (cache *execSSHTransportCache) closeAll() { + cache.mu.Lock() + transports := make([]execSSHTransport, 0, len(cache.entries)) + for key, transport := range cache.entries { + delete(cache.entries, key) + transports = append(transports, transport) } - pool.mu.Unlock() + cache.mu.Unlock() - for _, entry := range entries { - _ = entry.transport.Close() + for _, transport := range transports { + _ = transport.Close() } } - -type execSSHTransportLease struct { - pool *execSSHTransportPool - entry *execSSHTransportEntry - reused bool - - releaseOnce sync.Once -} - -func (lease *execSSHTransportLease) transport() execSSHTransport { - return lease.entry.transport -} - -func (lease *execSSHTransportLease) release() { - lease.releaseOnce.Do(func() { - lease.pool.release(lease.entry) - }) -} diff --git a/internal/controller/exec_ssh_pool_test.go b/internal/controller/exec_ssh_pool_test.go index 47cef9c..596ff1c 100644 --- a/internal/controller/exec_ssh_pool_test.go +++ b/internal/controller/exec_ssh_pool_test.go @@ -1,7 +1,6 @@ package controller import ( - "context" "errors" "sync" "sync/atomic" @@ -31,8 +30,8 @@ func (transport *fakeExecSSHTransport) Close() error { return nil } -func TestExecSSHTransportPoolConcurrentAcquireReusesOneTransport(t *testing.T) { - pool := newExecSSHTransportPool() +func TestExecSSHTransportCacheConcurrentGetOrCreateReusesOneTransport(t *testing.T) { + cache := newExecSSHTransportCache() key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1} createStarted := make(chan struct{}) @@ -49,24 +48,22 @@ func TestExecSSHTransportPoolConcurrentAcquireReusesOneTransport(t *testing.T) { return transport, nil } - const leasesCount = 16 - leases := make([]*execSSHTransportLease, leasesCount) - errCh := make(chan error, leasesCount) + const callersCount = 16 + transports := make([]execSSHTransport, callersCount) + reused := make([]bool, callersCount) + errCh := make(chan error, callersCount) var wg sync.WaitGroup - wg.Add(leasesCount) - for i := range leasesCount { + wg.Add(callersCount) + for i := range callersCount { go func() { defer wg.Done() - lease, err := pool.acquire(context.Background(), key, create) + var err error + transports[i], reused[i], err = cache.getOrCreate(key, create) if err != nil { errCh <- err - - return } - - leases[i] = lease }() } @@ -80,37 +77,14 @@ func TestExecSSHTransportPoolConcurrentAcquireReusesOneTransport(t *testing.T) { } require.EqualValues(t, 1, createCalls.Load()) - for _, lease := range leases { - require.NotNil(t, lease) - require.Same(t, transport, lease.transport()) - lease.release() + for _, cachedTransport := range transports { + require.Same(t, transport, cachedTransport) } + require.Equal(t, callersCount-1, countTrue(reused)) } -func TestExecSSHTransportPoolClosesOnLastRelease(t *testing.T) { - pool := newExecSSHTransportPool() - key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1} - transport := &fakeExecSSHTransport{} - - create := func() (execSSHTransport, error) { - return transport, nil - } - - firstLease, err := pool.acquire(context.Background(), key, create) - require.NoError(t, err) - secondLease, err := pool.acquire(context.Background(), key, create) - require.NoError(t, err) - - firstLease.release() - require.EqualValues(t, 0, transport.closeCalls.Load()) - - secondLease.release() - require.EqualValues(t, 1, transport.closeCalls.Load()) - require.Empty(t, pool.entries) -} - -func TestExecSSHTransportPoolSeparatesVMIncarnations(t *testing.T) { - pool := newExecSSHTransportPool() +func TestExecSSHTransportCacheSeparatesVMIncarnations(t *testing.T) { + cache := newExecSSHTransportCache() var createCalls atomic.Int32 create := func() (execSSHTransport, error) { @@ -119,28 +93,59 @@ func TestExecSSHTransportPoolSeparatesVMIncarnations(t *testing.T) { return &fakeExecSSHTransport{}, nil } - firstLease, err := pool.acquire(context.Background(), + firstTransport, _, err := cache.getOrCreate( execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}, create) require.NoError(t, err) - secondLease, err := pool.acquire(context.Background(), + secondTransport, _, err := cache.getOrCreate( execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 2}, create) require.NoError(t, err) - defer firstLease.release() - defer secondLease.release() require.EqualValues(t, 2, createCalls.Load()) - require.NotSame(t, firstLease.transport(), secondLease.transport()) + require.NotSame(t, firstTransport, secondTransport) } -func TestExecSSHTransportPoolKeepsSharedTransportAfterSessionCreationFailure(t *testing.T) { - pool := newExecSSHTransportPool() +func TestExecSSHTransportCacheDiscardClosesExpectedTransport(t *testing.T) { + cache := newExecSSHTransportCache() + key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1} + transport := &fakeExecSSHTransport{} + + cachedTransport, _, err := cache.getOrCreate(key, func() (execSSHTransport, error) { + return transport, nil + }) + require.NoError(t, err) + + cache.discard(key, cachedTransport) + require.EqualValues(t, 1, transport.closeCalls.Load()) + require.Empty(t, cache.entries) +} + +func TestExecSSHTransportCacheDiscardIgnoresReplacedTransport(t *testing.T) { + cache := newExecSSHTransportCache() + key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1} + originalTransport := &fakeExecSSHTransport{} + replacementTransport := &fakeExecSSHTransport{} + + cachedTransport, _, err := cache.getOrCreate(key, func() (execSSHTransport, error) { + return originalTransport, nil + }) + require.NoError(t, err) + + cache.entries[key] = replacementTransport + cache.discard(key, cachedTransport) + + require.EqualValues(t, 0, originalTransport.closeCalls.Load()) + require.Same(t, replacementTransport, cache.entries[key]) +} + +func TestExecSSHTransportCacheKeepsTransportAcrossExecs(t *testing.T) { + cache := newExecSSHTransportCache() key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1} - var createCalls atomic.Int32 transport := &fakeExecSSHTransport{ newExec: func(sshexec.Options) (sshExecRunner, error) { return nil, errors.New("failed to open channel") }, } + var createCalls atomic.Int32 create := func() (execSSHTransport, error) { createCalls.Add(1) @@ -148,39 +153,41 @@ func TestExecSSHTransportPoolKeepsSharedTransportAfterSessionCreationFailure(t * return transport, nil } - activeLease, err := pool.acquire(context.Background(), key, create) + firstTransport, reused, err := cache.getOrCreate(key, create) require.NoError(t, err) - failedLease, err := pool.acquire(context.Background(), key, create) + require.False(t, reused) + secondTransport, reused, err := cache.getOrCreate(key, create) require.NoError(t, err) - require.True(t, failedLease.reused) - - _, err = failedLease.transport().NewExec(sshexec.Options{}) - require.ErrorContains(t, err, "failed to open channel") - failedLease.release() + require.True(t, reused) + require.Same(t, firstTransport, secondTransport) require.EqualValues(t, 1, createCalls.Load()) require.EqualValues(t, 0, transport.closeCalls.Load()) - require.Len(t, pool.entries, 1) - - activeLease.release() - require.EqualValues(t, 1, transport.closeCalls.Load()) } -func TestExecSSHTransportPoolCloseAllClosesActiveTransports(t *testing.T) { - pool := newExecSSHTransportPool() +func TestExecSSHTransportCacheCloseAllClosesTransports(t *testing.T) { + cache := newExecSSHTransportCache() transport := &fakeExecSSHTransport{} - lease, err := pool.acquire(context.Background(), + _, _, err := cache.getOrCreate( execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}, func() (execSSHTransport, error) { return transport, nil }) require.NoError(t, err) - pool.closeAll() - require.EqualValues(t, 1, transport.closeCalls.Load()) - require.Empty(t, pool.entries) - - lease.release() + cache.closeAll() require.EqualValues(t, 1, transport.closeCalls.Load()) + require.Empty(t, cache.entries) +} + +func countTrue(values []bool) int { + count := 0 + for _, value := range values { + if value { + count++ + } + } + + return count }