orchard/internal/controller/exec_ssh_pool_test.go

194 lines
4.9 KiB
Go

package controller
import (
"errors"
"sync"
"sync/atomic"
"testing"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
"github.com/stretchr/testify/require"
)
// fakeExecSSHTransport is a test double with an overridable NewExec hook
// and a counter that records how many times Close was called.
type fakeExecSSHTransport struct {
	// newExec, when non-nil, is invoked by NewExec instead of the default
	// behavior of returning a fresh fakeExec.
	newExec func(sshexec.Options) (sshExecRunner, error)
	// closeCalls is incremented by one on every Close call.
	closeCalls atomic.Int32
}
// NewExec returns a new fakeExec and a nil error unless the newExec hook
// is set, in which case the call is forwarded to the hook and its results
// are returned unchanged.
func (transport *fakeExecSSHTransport) NewExec(options sshexec.Options) (sshExecRunner, error) {
	if transport.newExec == nil {
		return &fakeExec{}, nil
	}

	return transport.newExec(options)
}
// Close records the call by incrementing closeCalls and always returns nil.
func (transport *fakeExecSSHTransport) Close() error {
	transport.closeCalls.Add(1)
	return nil
}
// TestExecSSHTransportCacheConcurrentGetOrCreateReusesOneTransport starts
// several goroutines that call getOrCreate with the same key while the
// factory is held open, then checks that the factory ran exactly once,
// that every caller received the same transport, and that all callers but
// one were told the transport was reused.
func TestExecSSHTransportCacheConcurrentGetOrCreateReusesOneTransport(t *testing.T) {
	const callersCount = 16

	cache := newExecSSHTransportCache()
	key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
	sharedTransport := &fakeExecSSHTransport{}

	var (
		factoryCalls   atomic.Int32
		factoryEntered = make(chan struct{})
		factoryRelease = make(chan struct{})
	)

	factory := func() (execSSHTransport, error) {
		// Signal only on the first invocation; closing an already closed
		// channel would panic.
		if factoryCalls.Add(1) == 1 {
			close(factoryEntered)
		}

		// Hold the factory open until the test body releases it.
		<-factoryRelease

		return sharedTransport, nil
	}

	var (
		got       = make([]execSSHTransport, callersCount)
		wasReused = make([]bool, callersCount)
		failures  = make(chan error, callersCount)
		callers   sync.WaitGroup
	)

	for idx := range callersCount {
		callers.Add(1)

		go func() {
			defer callers.Done()

			// Each goroutine writes only to its own slot, so the result
			// slices need no extra synchronization.
			result, hit, err := cache.getOrCreate(key, factory)
			got[idx], wasReused[idx] = result, hit

			if err != nil {
				failures <- err
			}
		}()
	}

	// Wait until the first factory invocation has started, then let it
	// finish and wait for every caller to return.
	<-factoryEntered
	close(factoryRelease)
	callers.Wait()
	close(failures)

	for err := range failures {
		require.NoError(t, err)
	}

	require.EqualValues(t, 1, factoryCalls.Load())

	for idx := range got {
		require.Same(t, sharedTransport, got[idx])
	}

	require.Equal(t, callersCount-1, countTrue(wasReused))
}
// TestExecSSHTransportCacheSeparatesVMIncarnations checks that two keys
// differing only in restartCount each trigger their own factory call and
// yield distinct transports.
func TestExecSSHTransportCacheSeparatesVMIncarnations(t *testing.T) {
	cache := newExecSSHTransportCache()

	var factoryCalls atomic.Int32
	factory := func() (execSSHTransport, error) {
		factoryCalls.Add(1)

		return &fakeExecSSHTransport{}, nil
	}

	firstKey := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
	secondKey := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 2}

	first, _, err := cache.getOrCreate(firstKey, factory)
	require.NoError(t, err)

	second, _, err := cache.getOrCreate(secondKey, factory)
	require.NoError(t, err)

	require.EqualValues(t, 2, factoryCalls.Load())
	require.NotSame(t, first, second)
}
// TestExecSSHTransportCacheDiscardClosesExpectedTransport checks that
// discarding the transport currently stored under a key closes it exactly
// once and leaves the cache's entries empty.
func TestExecSSHTransportCacheDiscardClosesExpectedTransport(t *testing.T) {
	cache := newExecSSHTransportCache()
	key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}

	fake := &fakeExecSSHTransport{}
	factory := func() (execSSHTransport, error) {
		return fake, nil
	}

	cached, _, err := cache.getOrCreate(key, factory)
	require.NoError(t, err)

	cache.discard(key, cached)

	require.EqualValues(t, 1, fake.closeCalls.Load())
	require.Empty(t, cache.entries)
}
// TestExecSSHTransportCacheDiscardIgnoresReplacedTransport checks that
// calling discard with a transport that is no longer the one stored under
// the key neither closes that transport nor removes the stored one.
func TestExecSSHTransportCacheDiscardIgnoresReplacedTransport(t *testing.T) {
	cache := newExecSSHTransportCache()
	key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}

	original := &fakeExecSSHTransport{}
	replacement := &fakeExecSSHTransport{}
	factory := func() (execSSHTransport, error) {
		return original, nil
	}

	stale, _, err := cache.getOrCreate(key, factory)
	require.NoError(t, err)

	// Overwrite the map entry directly so that the transport handed to
	// discard below no longer matches what is stored under the key.
	cache.entries[key] = replacement

	cache.discard(key, stale)

	require.EqualValues(t, 0, original.closeCalls.Load())
	require.Same(t, replacement, cache.entries[key])
}
// TestExecSSHTransportCacheKeepsTransportAcrossExecs checks that two
// consecutive getOrCreate calls with the same key run the factory once,
// return the same transport (not reused, then reused), and never close it.
//
// NOTE(review): the newExec hook configured on the fake is never invoked
// by this test — confirm whether an exec call on the transport was
// intended between the two lookups.
func TestExecSSHTransportCacheKeepsTransportAcrossExecs(t *testing.T) {
	cache := newExecSSHTransportCache()
	key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}

	failingExec := func(sshexec.Options) (sshExecRunner, error) {
		return nil, errors.New("failed to open channel")
	}
	fake := &fakeExecSSHTransport{newExec: failingExec}

	var factoryCalls atomic.Int32
	factory := func() (execSSHTransport, error) {
		factoryCalls.Add(1)

		return fake, nil
	}

	first, firstReused, err := cache.getOrCreate(key, factory)
	require.NoError(t, err)
	require.False(t, firstReused)

	second, secondReused, err := cache.getOrCreate(key, factory)
	require.NoError(t, err)
	require.True(t, secondReused)

	require.Same(t, first, second)
	require.EqualValues(t, 1, factoryCalls.Load())
	require.EqualValues(t, 0, fake.closeCalls.Load())
}
// TestExecSSHTransportCacheCloseAllClosesTransports checks that closeAll
// closes a cached transport exactly once and leaves the cache's entries
// empty.
func TestExecSSHTransportCacheCloseAllClosesTransports(t *testing.T) {
	cache := newExecSSHTransportCache()
	key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}

	fake := &fakeExecSSHTransport{}
	factory := func() (execSSHTransport, error) {
		return fake, nil
	}

	_, _, err := cache.getOrCreate(key, factory)
	require.NoError(t, err)

	cache.closeAll()

	require.EqualValues(t, 1, fake.closeCalls.Load())
	require.Empty(t, cache.entries)
}
// countTrue reports how many elements of values are true. A nil or empty
// slice yields 0.
func countTrue(values []bool) int {
	var total int

	for idx := range values {
		if !values[idx] {
			continue
		}

		total++
	}

	return total
}