171 lines
3.5 KiB
Go
171 lines
3.5 KiB
Go
package controller
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/cirruslabs/orchard/internal/controller/sshexec"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// fakeSSHExecClient is an in-memory stand-in for an SSH exec client used by
// the pool tests. It counts how often Keepalive and Close are invoked and
// exposes a done channel that Close closes at most once.
type fakeSSHExecClient struct {
	// done is closed by Close; closeOnce makes repeated Close calls safe.
	done      chan struct{}
	closeOnce sync.Once

	// Call counters; atomics because tests read them while pool goroutines
	// may still be calling into the fake.
	closeCalls     atomic.Int32
	keepaliveCalls atomic.Int32

	// keepaliveErr is what every Keepalive call returns. Tests assign it
	// before handing the client to the pool.
	keepaliveErr error
}
|
|
|
|
func newFakeSSHExecClient() *fakeSSHExecClient {
|
|
return &fakeSSHExecClient{
|
|
done: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (client *fakeSSHExecClient) NewExec(options sshexec.Options) (*sshexec.Exec, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (client *fakeSSHExecClient) Keepalive() error {
|
|
client.keepaliveCalls.Add(1)
|
|
|
|
return client.keepaliveErr
|
|
}
|
|
|
|
func (client *fakeSSHExecClient) Done() <-chan struct{} {
|
|
return client.done
|
|
}
|
|
|
|
func (client *fakeSSHExecClient) Err() error {
|
|
<-client.done
|
|
|
|
return errors.New("disconnected")
|
|
}
|
|
|
|
func (client *fakeSSHExecClient) Close() error {
|
|
client.closeCalls.Add(1)
|
|
client.closeOnce.Do(func() {
|
|
close(client.done)
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (client *fakeSSHExecClient) ShouldRecreateAfter(err error) bool {
|
|
return true
|
|
}
|
|
|
|
func TestExecSSHClientPoolDeduplicatesConcurrentInitialization(t *testing.T) {
|
|
pool := newExecSSHClientPool(0, zap.NewNop().Sugar())
|
|
client := newFakeSSHExecClient()
|
|
|
|
createStarted := make(chan struct{})
|
|
releaseCreate := make(chan struct{})
|
|
|
|
var createCalls atomic.Int32
|
|
create := func() (sshExecClient, error) {
|
|
if createCalls.Add(1) == 1 {
|
|
close(createStarted)
|
|
}
|
|
|
|
<-releaseCreate
|
|
|
|
return client, nil
|
|
}
|
|
|
|
const callers = 16
|
|
|
|
start := make(chan struct{})
|
|
errCh := make(chan error, callers)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(callers)
|
|
|
|
for range callers {
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
<-start
|
|
|
|
_, err := pool.newExec("vm-1", sshexec.Options{}, create)
|
|
if err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
}
|
|
|
|
close(start)
|
|
<-createStarted
|
|
close(releaseCreate)
|
|
|
|
wg.Wait()
|
|
close(errCh)
|
|
|
|
for err := range errCh {
|
|
require.NoError(t, err)
|
|
}
|
|
require.EqualValues(t, 1, createCalls.Load())
|
|
|
|
pool.closeAll()
|
|
}
|
|
|
|
func TestExecSSHClientPoolKeepaliveInvalidatesClient(t *testing.T) {
|
|
pool := newExecSSHClientPool(time.Millisecond, zap.NewNop().Sugar())
|
|
client := newFakeSSHExecClient()
|
|
client.keepaliveErr = errors.New("boom")
|
|
|
|
_, err := pool.newExec("vm-1", sshexec.Options{}, func() (sshExecClient, error) {
|
|
return client, nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
return client.keepaliveCalls.Load() > 0 && client.closeCalls.Load() == 1
|
|
}, time.Second, time.Millisecond)
|
|
|
|
replacement := newFakeSSHExecClient()
|
|
var createCalls atomic.Int32
|
|
_, err = pool.newExec("vm-1", sshexec.Options{}, func() (sshExecClient, error) {
|
|
createCalls.Add(1)
|
|
|
|
return replacement, nil
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, createCalls.Load())
|
|
|
|
pool.closeAll()
|
|
}
|
|
|
|
func TestExecSSHClientPoolWaitClearsDisconnectedClient(t *testing.T) {
|
|
pool := newExecSSHClientPool(0, zap.NewNop().Sugar())
|
|
client := newFakeSSHExecClient()
|
|
|
|
_, err := pool.newExec("vm-1", sshexec.Options{}, func() (sshExecClient, error) {
|
|
return client, nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, client.Close())
|
|
|
|
replacement := newFakeSSHExecClient()
|
|
var createCalls atomic.Int32
|
|
require.Eventually(t, func() bool {
|
|
_, err := pool.newExec("vm-1", sshexec.Options{}, func() (sshExecClient, error) {
|
|
createCalls.Add(1)
|
|
|
|
return replacement, nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return createCalls.Load() == 1
|
|
}, time.Second, time.Millisecond)
|
|
|
|
pool.closeAll()
|
|
}
|