Simplify exec SSH reuse with lazy cache

This commit is contained in:
Fedor Korotkov 2026-05-06 13:34:39 -04:00
parent 425d372fa5
commit cfb507e4ac
6 changed files with 123 additions and 231 deletions

View File

@ -189,8 +189,7 @@ func (controller *Controller) newSSHExecSession(
sessionContext, sessionContextCancel := context.WithCancel(context.Background())
type sshExecAttempt struct {
lease *execSSHTransportLease
exec sshExecRunner
exec sshExecRunner
}
transportKey := execSSHTransportKey{
@ -206,7 +205,7 @@ func (controller *Controller) newSSHExecSession(
retry.Attempts(0),
retry.LastErrorOnly(true),
).Do(func() (sshExecAttempt, error) {
lease, err := controller.execSSHPool.acquire(waitContext, transportKey, func() (execSSHTransport, error) {
transport, reused, err := controller.execSSHCache.getOrCreate(transportKey, func() (execSSHTransport, error) {
portForwardConn, err := controller.portForwardConnection(
context.Background(),
waitContext,
@ -229,17 +228,17 @@ func (controller *Controller) newSSHExecSession(
return sshExecAttempt{}, err
}
exec, err := lease.transport().NewExec(sshexec.Options{
exec, err := transport.NewExec(sshexec.Options{
Interactive: spec.interactive,
TTY: spec.tty,
Rows: spec.rows,
Cols: spec.cols,
})
if err != nil {
lease.release()
controller.execSSHCache.discard(transportKey, transport)
err = fmt.Errorf("failed to create SSH session for a VM: %w", err)
if lease.reused {
if reused {
return sshExecAttempt{}, retry.Unrecoverable(err)
}
@ -247,8 +246,7 @@ func (controller *Controller) newSSHExecSession(
}
return sshExecAttempt{
lease: lease,
exec: exec,
exec: exec,
}, nil
})
if err != nil {
@ -264,7 +262,6 @@ func (controller *Controller) newSSHExecSession(
spec,
runCommand,
attempt.exec,
attempt.lease.release,
registry,
controller.execSessionExitTTL,
policy,

View File

@ -66,7 +66,7 @@ type Controller struct {
sshNoClientAuth bool
sshServer *sshserver.SSHServer
execSessions *execSessionRegistry
execSSHPool *execSSHTransportPool
execSSHCache *execSSHTransportCache
single singleflight.Group
@ -81,7 +81,7 @@ func New(opts ...Option) (*Controller, error) {
execSessionExitTTL: 10 * time.Minute,
pingInterval: 30 * time.Second,
execSessions: newExecSessionRegistry(),
execSSHPool: newExecSSHTransportPool(),
execSSHCache: newExecSSHTransportCache(),
single: singleflight.Group{},
}
@ -315,7 +315,7 @@ func (controller *Controller) Run(ctx context.Context) error {
<-ctx.Done()
controller.execSessions.closeAll()
controller.execSSHPool.closeAll()
controller.execSSHCache.closeAll()
if err := controller.httpServer.Shutdown(ctx); err != nil {
controller.logger.Errorf("failed to cleanly shutdown the HTTP server: %v", err)

View File

@ -328,7 +328,6 @@ type execSession struct {
spec execSessionSpec
command string
exec sshExecRunner
release func()
registry *execSessionRegistry
exitTTL time.Duration
policy execSessionPolicy
@ -347,7 +346,6 @@ type execSession struct {
expiryTimer *time.Timer
startOnce sync.Once
closeOnce sync.Once
done chan struct{}
doneOnce sync.Once
}
@ -356,7 +354,6 @@ func newExecSession(
key execSessionKey,
command string,
exec sshExecRunner,
release func(),
registry *execSessionRegistry,
exitTTL time.Duration,
policy execSessionPolicy,
@ -370,7 +367,6 @@ func newExecSession(
execSessionSpec{command: command},
command,
exec,
release,
registry,
exitTTL,
policy,
@ -384,7 +380,6 @@ func newExecSessionWithContextAndSpec(
spec execSessionSpec,
command string,
exec sshExecRunner,
release func(),
registry *execSessionRegistry,
exitTTL time.Duration,
policy execSessionPolicy,
@ -398,7 +393,6 @@ func newExecSessionWithContextAndSpec(
spec: spec.clone(),
command: command,
exec: exec,
release: release,
registry: registry,
exitTTL: exitTTL,
policy: policy,
@ -574,7 +568,7 @@ func (session *execSession) close() {
closeSubscribers(subscribers)
session.cancel()
session.closeCommandResources()
_ = session.exec.Close()
if session.registry != nil {
session.registry.remove(session.key, session)
}
@ -658,7 +652,7 @@ func (session *execSession) markFinished() {
close(session.done)
})
session.closeCommandResources()
_ = session.exec.Close()
if shouldClose {
session.close()
@ -692,15 +686,6 @@ func (session *execSession) dropSubscriber(subscriber *execSessionSubscriber) {
session.detachLocked(subscriber)
}
func (session *execSession) closeCommandResources() {
session.closeOnce.Do(func() {
_ = session.exec.Close()
if session.release != nil {
session.release()
}
})
}
func cloneExecFrame(frame *execstream.Frame) *execstream.Frame {
if frame == nil {
return nil

View File

@ -126,7 +126,6 @@ func TestExecSessionStartRunsCommandOnlyOnce(t *testing.T) {
},
},
nil,
nil,
time.Minute,
reconnectableExecSessionPolicy,
)
@ -387,13 +386,9 @@ func TestExecSessionFinishKeepsReconnectableSubscriberOpen(t *testing.T) {
require.EqualValues(t, 2, noMoreHistory.Watermark)
}
func TestExecSessionFinishReleasesTransportWhileRetainingHistory(t *testing.T) {
func TestExecSessionFinishClosesCommandWhileRetainingHistory(t *testing.T) {
registry := newExecSessionRegistry()
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
var releaseCalls atomic.Int32
session.release = func() {
releaseCalls.Add(1)
}
session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")})
session.recordFrame(&execstream.Frame{
@ -403,7 +398,6 @@ func TestExecSessionFinishReleasesTransportWhileRetainingHistory(t *testing.T) {
session.markFinished()
require.False(t, session.closed)
require.EqualValues(t, 1, releaseCalls.Load())
require.EqualValues(t, 1, session.exec.(*fakeExec).closeCalls.Load())
subscriber, err := session.attach()

View File

@ -1,7 +1,6 @@
package controller
import (
"context"
"sync"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
@ -30,153 +29,63 @@ type execSSHTransportKey struct {
restartCount uint64
}
type execSSHTransportEntry struct {
key execSSHTransportKey
transport execSSHTransport
refs int
closed bool
// execSSHTransportCache stores SSH transports keyed by execSSHTransportKey.
type execSSHTransportCache struct {
// mu is locked by the cache's methods before they read or modify entries.
mu sync.Mutex
// entries maps a transport key to the transport stored for it.
entries map[execSSHTransportKey]execSSHTransport
}
type execSSHTransportCreation struct {
done chan struct{}
err error
}
type execSSHTransportPool struct {
mu sync.Mutex
entries map[execSSHTransportKey]*execSSHTransportEntry
creating map[execSSHTransportKey]*execSSHTransportCreation
}
func newExecSSHTransportPool() *execSSHTransportPool {
return &execSSHTransportPool{
entries: map[execSSHTransportKey]*execSSHTransportEntry{},
creating: map[execSSHTransportKey]*execSSHTransportCreation{},
// newExecSSHTransportCache returns a cache whose entries map is initialized
// and empty, so it can be written to immediately.
func newExecSSHTransportCache() *execSSHTransportCache {
	cache := new(execSSHTransportCache)
	cache.entries = make(map[execSSHTransportKey]execSSHTransport)

	return cache
}
func (pool *execSSHTransportPool) acquire(
ctx context.Context,
func (cache *execSSHTransportCache) getOrCreate(
key execSSHTransportKey,
create func() (execSSHTransport, error),
) (*execSSHTransportLease, error) {
for {
pool.mu.Lock()
) (execSSHTransport, bool, error) {
cache.mu.Lock()
defer cache.mu.Unlock()
if entry, ok := pool.entries[key]; ok {
entry.refs++
pool.mu.Unlock()
return &execSSHTransportLease{
pool: pool,
entry: entry,
reused: true,
}, nil
}
if creation, ok := pool.creating[key]; ok {
pool.mu.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-creation.done:
if creation.err != nil {
return nil, creation.err
}
}
continue
}
creation := &execSSHTransportCreation{done: make(chan struct{})}
pool.creating[key] = creation
pool.mu.Unlock()
transport, err := create()
pool.mu.Lock()
delete(pool.creating, key)
creation.err = err
var entry *execSSHTransportEntry
if err == nil {
entry = &execSSHTransportEntry{
key: key,
transport: transport,
refs: 1,
}
pool.entries[key] = entry
}
close(creation.done)
pool.mu.Unlock()
if err != nil {
return nil, err
}
return &execSSHTransportLease{
pool: pool,
entry: entry,
}, nil
if transport, ok := cache.entries[key]; ok {
return transport, true, nil
}
transport, err := create()
if err != nil {
return nil, false, err
}
cache.entries[key] = transport
return transport, false, nil
}
func (pool *execSSHTransportPool) release(entry *execSSHTransportEntry) {
func (cache *execSSHTransportCache) discard(key execSSHTransportKey, expected execSSHTransport) {
var transport execSSHTransport
pool.mu.Lock()
if entry.closed {
pool.mu.Unlock()
return
cache.mu.Lock()
if cache.entries[key] == expected {
transport = expected
delete(cache.entries, key)
}
entry.refs--
if entry.refs == 0 {
entry.closed = true
if pool.entries[entry.key] == entry {
delete(pool.entries, entry.key)
}
transport = entry.transport
}
pool.mu.Unlock()
cache.mu.Unlock()
if transport != nil {
_ = transport.Close()
}
}
func (pool *execSSHTransportPool) closeAll() {
pool.mu.Lock()
entries := make([]*execSSHTransportEntry, 0, len(pool.entries))
for key, entry := range pool.entries {
entry.closed = true
delete(pool.entries, key)
entries = append(entries, entry)
func (cache *execSSHTransportCache) closeAll() {
cache.mu.Lock()
transports := make([]execSSHTransport, 0, len(cache.entries))
for key, transport := range cache.entries {
delete(cache.entries, key)
transports = append(transports, transport)
}
pool.mu.Unlock()
cache.mu.Unlock()
for _, entry := range entries {
_ = entry.transport.Close()
for _, transport := range transports {
_ = transport.Close()
}
}
type execSSHTransportLease struct {
pool *execSSHTransportPool
entry *execSSHTransportEntry
reused bool
releaseOnce sync.Once
}
func (lease *execSSHTransportLease) transport() execSSHTransport {
return lease.entry.transport
}
func (lease *execSSHTransportLease) release() {
lease.releaseOnce.Do(func() {
lease.pool.release(lease.entry)
})
}

View File

@ -1,7 +1,6 @@
package controller
import (
"context"
"errors"
"sync"
"sync/atomic"
@ -31,8 +30,8 @@ func (transport *fakeExecSSHTransport) Close() error {
return nil
}
func TestExecSSHTransportPoolConcurrentAcquireReusesOneTransport(t *testing.T) {
pool := newExecSSHTransportPool()
func TestExecSSHTransportCacheConcurrentGetOrCreateReusesOneTransport(t *testing.T) {
cache := newExecSSHTransportCache()
key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
createStarted := make(chan struct{})
@ -49,24 +48,22 @@ func TestExecSSHTransportPoolConcurrentAcquireReusesOneTransport(t *testing.T) {
return transport, nil
}
const leasesCount = 16
leases := make([]*execSSHTransportLease, leasesCount)
errCh := make(chan error, leasesCount)
const callersCount = 16
transports := make([]execSSHTransport, callersCount)
reused := make([]bool, callersCount)
errCh := make(chan error, callersCount)
var wg sync.WaitGroup
wg.Add(leasesCount)
for i := range leasesCount {
wg.Add(callersCount)
for i := range callersCount {
go func() {
defer wg.Done()
lease, err := pool.acquire(context.Background(), key, create)
var err error
transports[i], reused[i], err = cache.getOrCreate(key, create)
if err != nil {
errCh <- err
return
}
leases[i] = lease
}()
}
@ -80,37 +77,14 @@ func TestExecSSHTransportPoolConcurrentAcquireReusesOneTransport(t *testing.T) {
}
require.EqualValues(t, 1, createCalls.Load())
for _, lease := range leases {
require.NotNil(t, lease)
require.Same(t, transport, lease.transport())
lease.release()
for _, cachedTransport := range transports {
require.Same(t, transport, cachedTransport)
}
require.Equal(t, callersCount-1, countTrue(reused))
}
func TestExecSSHTransportPoolClosesOnLastRelease(t *testing.T) {
pool := newExecSSHTransportPool()
key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
transport := &fakeExecSSHTransport{}
create := func() (execSSHTransport, error) {
return transport, nil
}
firstLease, err := pool.acquire(context.Background(), key, create)
require.NoError(t, err)
secondLease, err := pool.acquire(context.Background(), key, create)
require.NoError(t, err)
firstLease.release()
require.EqualValues(t, 0, transport.closeCalls.Load())
secondLease.release()
require.EqualValues(t, 1, transport.closeCalls.Load())
require.Empty(t, pool.entries)
}
func TestExecSSHTransportPoolSeparatesVMIncarnations(t *testing.T) {
pool := newExecSSHTransportPool()
func TestExecSSHTransportCacheSeparatesVMIncarnations(t *testing.T) {
cache := newExecSSHTransportCache()
var createCalls atomic.Int32
create := func() (execSSHTransport, error) {
@ -119,28 +93,59 @@ func TestExecSSHTransportPoolSeparatesVMIncarnations(t *testing.T) {
return &fakeExecSSHTransport{}, nil
}
firstLease, err := pool.acquire(context.Background(),
firstTransport, _, err := cache.getOrCreate(
execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}, create)
require.NoError(t, err)
secondLease, err := pool.acquire(context.Background(),
secondTransport, _, err := cache.getOrCreate(
execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 2}, create)
require.NoError(t, err)
defer firstLease.release()
defer secondLease.release()
require.EqualValues(t, 2, createCalls.Load())
require.NotSame(t, firstLease.transport(), secondLease.transport())
require.NotSame(t, firstTransport, secondTransport)
}
func TestExecSSHTransportPoolKeepsSharedTransportAfterSessionCreationFailure(t *testing.T) {
pool := newExecSSHTransportPool()
// TestExecSSHTransportCacheDiscardClosesExpectedTransport verifies that
// discarding the transport currently stored under a key closes that transport
// exactly once and leaves the cache without entries.
func TestExecSSHTransportCacheDiscardClosesExpectedTransport(t *testing.T) {
	cache := newExecSSHTransportCache()
	key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}

	fake := &fakeExecSSHTransport{}
	create := func() (execSSHTransport, error) {
		return fake, nil
	}

	stored, _, err := cache.getOrCreate(key, create)
	require.NoError(t, err)

	cache.discard(key, stored)

	require.EqualValues(t, 1, fake.closeCalls.Load())
	require.Empty(t, cache.entries)
}
// TestExecSSHTransportCacheDiscardIgnoresReplacedTransport verifies that
// discard does nothing when the entry stored under the key is no longer the
// transport passed in: the passed transport is not closed and the entry that
// replaced it stays in the cache.
func TestExecSSHTransportCacheDiscardIgnoresReplacedTransport(t *testing.T) {
	cache := newExecSSHTransportCache()
	key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}

	stale := &fakeExecSSHTransport{}
	current := &fakeExecSSHTransport{}

	create := func() (execSSHTransport, error) {
		return stale, nil
	}
	handedOut, _, err := cache.getOrCreate(key, create)
	require.NoError(t, err)

	// Overwrite the entry directly so the key now maps to a different transport
	// than the one handed out above.
	cache.entries[key] = current

	cache.discard(key, handedOut)

	require.EqualValues(t, 0, stale.closeCalls.Load())
	require.Same(t, current, cache.entries[key])
}
func TestExecSSHTransportCacheKeepsTransportAcrossExecs(t *testing.T) {
cache := newExecSSHTransportCache()
key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
var createCalls atomic.Int32
transport := &fakeExecSSHTransport{
newExec: func(sshexec.Options) (sshExecRunner, error) {
return nil, errors.New("failed to open channel")
},
}
var createCalls atomic.Int32
create := func() (execSSHTransport, error) {
createCalls.Add(1)
@ -148,39 +153,41 @@ func TestExecSSHTransportPoolKeepsSharedTransportAfterSessionCreationFailure(t *
return transport, nil
}
activeLease, err := pool.acquire(context.Background(), key, create)
firstTransport, reused, err := cache.getOrCreate(key, create)
require.NoError(t, err)
failedLease, err := pool.acquire(context.Background(), key, create)
require.False(t, reused)
secondTransport, reused, err := cache.getOrCreate(key, create)
require.NoError(t, err)
require.True(t, failedLease.reused)
_, err = failedLease.transport().NewExec(sshexec.Options{})
require.ErrorContains(t, err, "failed to open channel")
failedLease.release()
require.True(t, reused)
require.Same(t, firstTransport, secondTransport)
require.EqualValues(t, 1, createCalls.Load())
require.EqualValues(t, 0, transport.closeCalls.Load())
require.Len(t, pool.entries, 1)
activeLease.release()
require.EqualValues(t, 1, transport.closeCalls.Load())
}
func TestExecSSHTransportPoolCloseAllClosesActiveTransports(t *testing.T) {
pool := newExecSSHTransportPool()
func TestExecSSHTransportCacheCloseAllClosesTransports(t *testing.T) {
cache := newExecSSHTransportCache()
transport := &fakeExecSSHTransport{}
lease, err := pool.acquire(context.Background(),
_, _, err := cache.getOrCreate(
execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1},
func() (execSSHTransport, error) {
return transport, nil
})
require.NoError(t, err)
pool.closeAll()
require.EqualValues(t, 1, transport.closeCalls.Load())
require.Empty(t, pool.entries)
lease.release()
cache.closeAll()
require.EqualValues(t, 1, transport.closeCalls.Load())
require.Empty(t, cache.entries)
}
// countTrue reports how many elements of values are true. A nil or empty
// slice yields 0.
func countTrue(values []bool) int {
	var total int

	for i := 0; i < len(values); i++ {
		if !values[i] {
			continue
		}

		total++
	}

	return total
}