194 lines
4.9 KiB
Go
194 lines
4.9 KiB
Go
package controller
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/cirruslabs/orchard/internal/controller/sshexec"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type fakeExecSSHTransport struct {
|
|
newExec func(sshexec.Options) (sshExecRunner, error)
|
|
|
|
closeCalls atomic.Int32
|
|
}
|
|
|
|
func (transport *fakeExecSSHTransport) NewExec(options sshexec.Options) (sshExecRunner, error) {
|
|
if transport.newExec != nil {
|
|
return transport.newExec(options)
|
|
}
|
|
|
|
return &fakeExec{}, nil
|
|
}
|
|
|
|
func (transport *fakeExecSSHTransport) Close() error {
|
|
transport.closeCalls.Add(1)
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestExecSSHTransportCacheConcurrentGetOrCreateReusesOneTransport(t *testing.T) {
|
|
cache := newExecSSHTransportCache()
|
|
key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
|
|
|
|
createStarted := make(chan struct{})
|
|
releaseCreate := make(chan struct{})
|
|
var createCalls atomic.Int32
|
|
transport := &fakeExecSSHTransport{}
|
|
|
|
create := func() (execSSHTransport, error) {
|
|
if createCalls.Add(1) == 1 {
|
|
close(createStarted)
|
|
}
|
|
<-releaseCreate
|
|
|
|
return transport, nil
|
|
}
|
|
|
|
const callersCount = 16
|
|
transports := make([]execSSHTransport, callersCount)
|
|
reused := make([]bool, callersCount)
|
|
errCh := make(chan error, callersCount)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(callersCount)
|
|
for i := range callersCount {
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
var err error
|
|
transports[i], reused[i], err = cache.getOrCreate(key, create)
|
|
if err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
}
|
|
|
|
<-createStarted
|
|
close(releaseCreate)
|
|
wg.Wait()
|
|
close(errCh)
|
|
|
|
for err := range errCh {
|
|
require.NoError(t, err)
|
|
}
|
|
require.EqualValues(t, 1, createCalls.Load())
|
|
|
|
for _, cachedTransport := range transports {
|
|
require.Same(t, transport, cachedTransport)
|
|
}
|
|
require.Equal(t, callersCount-1, countTrue(reused))
|
|
}
|
|
|
|
func TestExecSSHTransportCacheSeparatesVMIncarnations(t *testing.T) {
|
|
cache := newExecSSHTransportCache()
|
|
var createCalls atomic.Int32
|
|
|
|
create := func() (execSSHTransport, error) {
|
|
createCalls.Add(1)
|
|
|
|
return &fakeExecSSHTransport{}, nil
|
|
}
|
|
|
|
firstTransport, _, err := cache.getOrCreate(
|
|
execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}, create)
|
|
require.NoError(t, err)
|
|
secondTransport, _, err := cache.getOrCreate(
|
|
execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 2}, create)
|
|
require.NoError(t, err)
|
|
|
|
require.EqualValues(t, 2, createCalls.Load())
|
|
require.NotSame(t, firstTransport, secondTransport)
|
|
}
|
|
|
|
func TestExecSSHTransportCacheDiscardClosesExpectedTransport(t *testing.T) {
|
|
cache := newExecSSHTransportCache()
|
|
key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
|
|
transport := &fakeExecSSHTransport{}
|
|
|
|
cachedTransport, _, err := cache.getOrCreate(key, func() (execSSHTransport, error) {
|
|
return transport, nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
cache.discard(key, cachedTransport)
|
|
require.EqualValues(t, 1, transport.closeCalls.Load())
|
|
require.Empty(t, cache.entries)
|
|
}
|
|
|
|
func TestExecSSHTransportCacheDiscardIgnoresReplacedTransport(t *testing.T) {
|
|
cache := newExecSSHTransportCache()
|
|
key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
|
|
originalTransport := &fakeExecSSHTransport{}
|
|
replacementTransport := &fakeExecSSHTransport{}
|
|
|
|
cachedTransport, _, err := cache.getOrCreate(key, func() (execSSHTransport, error) {
|
|
return originalTransport, nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
cache.entries[key] = replacementTransport
|
|
cache.discard(key, cachedTransport)
|
|
|
|
require.EqualValues(t, 0, originalTransport.closeCalls.Load())
|
|
require.Same(t, replacementTransport, cache.entries[key])
|
|
}
|
|
|
|
func TestExecSSHTransportCacheKeepsTransportAcrossExecs(t *testing.T) {
|
|
cache := newExecSSHTransportCache()
|
|
key := execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1}
|
|
transport := &fakeExecSSHTransport{
|
|
newExec: func(sshexec.Options) (sshExecRunner, error) {
|
|
return nil, errors.New("failed to open channel")
|
|
},
|
|
}
|
|
var createCalls atomic.Int32
|
|
|
|
create := func() (execSSHTransport, error) {
|
|
createCalls.Add(1)
|
|
|
|
return transport, nil
|
|
}
|
|
|
|
firstTransport, reused, err := cache.getOrCreate(key, create)
|
|
require.NoError(t, err)
|
|
require.False(t, reused)
|
|
secondTransport, reused, err := cache.getOrCreate(key, create)
|
|
require.NoError(t, err)
|
|
require.True(t, reused)
|
|
|
|
require.Same(t, firstTransport, secondTransport)
|
|
require.EqualValues(t, 1, createCalls.Load())
|
|
require.EqualValues(t, 0, transport.closeCalls.Load())
|
|
}
|
|
|
|
func TestExecSSHTransportCacheCloseAllClosesTransports(t *testing.T) {
|
|
cache := newExecSSHTransportCache()
|
|
transport := &fakeExecSSHTransport{}
|
|
|
|
_, _, err := cache.getOrCreate(
|
|
execSSHTransportKey{workerName: "worker", vmUID: "vm", restartCount: 1},
|
|
func() (execSSHTransport, error) {
|
|
return transport, nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
cache.closeAll()
|
|
require.EqualValues(t, 1, transport.closeCalls.Load())
|
|
require.Empty(t, cache.entries)
|
|
}
|
|
|
|
// countTrue returns how many elements of values are true. A nil or empty
// slice yields 0.
func countTrue(values []bool) int {
	total := 0

	for i := range values {
		if !values[i] {
			continue
		}

		total++
	}

	return total
}
|