orchard/internal/controller/exec_ssh_clients_test.go

171 lines
3.5 KiB
Go

package controller
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
// fakeSSHExecClient is a test double that the tests hand to the pool as an
// sshExecClient. It counts Keepalive and Close invocations and owns a done
// channel that is closed by the first Close call.
type fakeSSHExecClient struct {
	// keepaliveErr is returned verbatim by every Keepalive call.
	keepaliveErr error

	// closeCalls and keepaliveCalls count invocations of Close and Keepalive.
	closeCalls     atomic.Int32
	keepaliveCalls atomic.Int32

	// done is closed at most once; closeOnce guards the close.
	done      chan struct{}
	closeOnce sync.Once
}

// newFakeSSHExecClient returns a fake whose done channel has not been closed yet.
func newFakeSSHExecClient() *fakeSSHExecClient {
	fake := new(fakeSSHExecClient)
	fake.done = make(chan struct{})

	return fake
}

// NewExec is a stub: it ignores its options and returns a nil Exec together
// with a nil error.
func (c *fakeSSHExecClient) NewExec(_ sshexec.Options) (*sshexec.Exec, error) {
	return nil, nil
}

// Keepalive records the call and returns the configured keepaliveErr.
func (c *fakeSSHExecClient) Keepalive() error {
	c.keepaliveCalls.Add(1)

	return c.keepaliveErr
}

// Done exposes the channel that is closed by the first Close call.
func (c *fakeSSHExecClient) Done() <-chan struct{} {
	return c.done
}

// Err blocks until the done channel has been closed and then returns a
// "disconnected" error.
func (c *fakeSSHExecClient) Err() error {
	<-c.done

	return errors.New("disconnected")
}

// Close counts every invocation but closes the done channel only on the first
// one, so calling it repeatedly does not panic. It always returns nil.
func (c *fakeSSHExecClient) Close() error {
	c.closeCalls.Add(1)

	signalDone := func() {
		close(c.done)
	}
	c.closeOnce.Do(signalDone)

	return nil
}

// ShouldRecreateAfter answers true for every error it is given.
func (c *fakeSSHExecClient) ShouldRecreateAfter(_ error) bool {
	return true
}
// TestExecSSHClientPoolDeduplicatesConcurrentInitialization starts 16
// goroutines that all call newExec for "vm-1" while the create callback is
// held open, and asserts that none of the calls fails and that the callback
// was invoked exactly once.
func TestExecSSHClientPoolDeduplicatesConcurrentInitialization(t *testing.T) {
	const workers = 16

	pool := newExecSSHClientPool(0, zap.NewNop().Sugar())
	fake := newFakeSSHExecClient()

	var (
		invocations  atomic.Int32
		firstEntered = make(chan struct{})
		unblock      = make(chan struct{})
	)

	// factory announces its first invocation and then parks until the test
	// releases it, keeping the creation in flight while the other callers
	// arrive.
	factory := func() (sshExecClient, error) {
		if invocations.Add(1) == 1 {
			close(firstEntered)
		}
		<-unblock

		return fake, nil
	}

	var (
		group   sync.WaitGroup
		gate    = make(chan struct{})
		results = make([]error, workers)
	)
	for i := range workers {
		group.Add(1)
		go func() {
			defer group.Done()
			<-gate
			// Each goroutine writes only its own slot; the slice is read
			// after Wait, so the assertions stay on the test goroutine.
			_, results[i] = pool.newExec("vm-1", sshexec.Options{}, factory)
		}()
	}

	close(gate)
	<-firstEntered
	close(unblock)
	group.Wait()

	for _, err := range results {
		require.NoError(t, err)
	}
	require.EqualValues(t, 1, invocations.Load())

	pool.closeAll()
}
// TestExecSSHClientPoolKeepaliveInvalidatesClient registers a client whose
// Keepalive always fails, waits until the pool has probed it at least once
// and closed it exactly once, and then asserts that the next newExec for the
// same key invokes its create callback.
func TestExecSSHClientPoolKeepaliveInvalidatesClient(t *testing.T) {
	// NOTE(review): the first argument is presumably the keepalive interval —
	// confirm against newExecSSHClientPool.
	pool := newExecSSHClientPool(time.Millisecond, zap.NewNop().Sugar())

	failing := newFakeSSHExecClient()
	failing.keepaliveErr = errors.New("boom")
	createFailing := func() (sshExecClient, error) {
		return failing, nil
	}

	_, err := pool.newExec("vm-1", sshexec.Options{}, createFailing)
	require.NoError(t, err)

	// The fake must have been probed and then closed exactly once.
	probedAndClosed := func() bool {
		if failing.keepaliveCalls.Load() <= 0 {
			return false
		}

		return failing.closeCalls.Load() == 1
	}
	require.Eventually(t, probedAndClosed, time.Second, time.Millisecond)

	var recreations atomic.Int32
	fresh := newFakeSSHExecClient()
	createFresh := func() (sshExecClient, error) {
		recreations.Add(1)

		return fresh, nil
	}

	_, err = pool.newExec("vm-1", sshexec.Options{}, createFresh)
	require.NoError(t, err)
	require.EqualValues(t, 1, recreations.Load())

	pool.closeAll()
}
// TestExecSSHClientPoolWaitClearsDisconnectedClient registers a client,
// closes it to simulate a disconnect, and then polls newExec for the same key
// until the create callback has been invoked exactly once, i.e. until the
// pool asks for a replacement client.
func TestExecSSHClientPoolWaitClearsDisconnectedClient(t *testing.T) {
	pool := newExecSSHClientPool(0, zap.NewNop().Sugar())

	client := newFakeSSHExecClient()
	_, err := pool.newExec("vm-1", sshexec.Options{}, func() (sshExecClient, error) {
		return client, nil
	})
	require.NoError(t, err)

	// Closing the fake closes its done channel, which is how a disconnect is
	// signalled to whoever watches Done().
	require.NoError(t, client.Close())

	replacement := newFakeSSHExecClient()
	var createCalls atomic.Int32

	// require.Eventually runs its condition on a separate goroutine, and
	// require.* fails via t.FailNow, which must only be called from the
	// goroutine running the test. Errors seen while polling are therefore
	// recorded atomically and reported from the test goroutine when the
	// function unwinds. This also fires when Eventually itself fails.
	var pollErr atomic.Pointer[error]
	defer func() {
		if failure := pollErr.Load(); failure != nil {
			t.Errorf("pool.newExec failed while polling for the replacement client: %v", *failure)
		}
	}()

	require.Eventually(t, func() bool {
		_, err := pool.newExec("vm-1", sshexec.Options{}, func() (sshExecClient, error) {
			createCalls.Add(1)
			return replacement, nil
		})
		if err != nil {
			pollErr.Store(&err)
			return false
		}

		return createCalls.Load() == 1
	}, time.Second, time.Millisecond)

	pool.closeAll()
}