Refactor executeVM logic and introduce helper functions for request parsing, WebSocket handling, and session execution. Add related tests.
This commit is contained in:
parent
dc0f5b45d0
commit
9c55014cc8
|
|
@ -32,44 +32,98 @@ type sshExecution struct {
|
|||
stdin io.WriteCloser
|
||||
}
|
||||
|
||||
type executeVMRequest struct {
|
||||
name string
|
||||
command string
|
||||
args []string
|
||||
wait time.Duration
|
||||
}
|
||||
|
||||
type executeSessionChannels struct {
|
||||
outputFrameCh chan execstream.Frame
|
||||
outputDoneCh chan struct{}
|
||||
outputErrCh chan error
|
||||
stdinErrCh chan error
|
||||
exitCodeCh chan int32
|
||||
exitErrCh chan error
|
||||
}
|
||||
|
||||
func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder {
|
||||
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
||||
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
||||
return responder
|
||||
}
|
||||
|
||||
name := ctx.Param("name")
|
||||
|
||||
command := ctx.Query("command")
|
||||
if command == "" {
|
||||
return responder.Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
args := ctx.QueryArray("arg")
|
||||
|
||||
waitRaw := ctx.DefaultQuery("wait", "10")
|
||||
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
||||
if err != nil {
|
||||
return responder.Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
||||
defer waitCancel()
|
||||
|
||||
vm, responderImpl := controller.waitForVM(waitCtx, name)
|
||||
request, responderImpl := parseExecuteVMRequest(ctx)
|
||||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
}
|
||||
|
||||
rvCtx, rvCancel := context.WithCancel(ctx)
|
||||
defer rvCancel()
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, request.wait)
|
||||
defer waitCancel()
|
||||
|
||||
vm, responderImpl := controller.waitForVM(waitCtx, request.name)
|
||||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
}
|
||||
|
||||
tunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, waitCtx, vm)
|
||||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
}
|
||||
defer func() {
|
||||
_ = tunnel.Close()
|
||||
}()
|
||||
|
||||
wsConn, err := acceptExecuteWebSocket(ctx)
|
||||
if err != nil {
|
||||
return responder.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
|
||||
return controller.executeVMViaSSHTunnel(ctx, tunnel, wsConn, vm, request.command, request.args)
|
||||
}
|
||||
|
||||
func parseExecuteVMRequest(ctx *gin.Context) (*executeVMRequest, responder.Responder) {
|
||||
command := ctx.Query("command")
|
||||
if command == "" {
|
||||
return nil, responder.Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
waitRaw := ctx.DefaultQuery("wait", "10")
|
||||
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
||||
if err != nil {
|
||||
return nil, responder.Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
return &executeVMRequest{
|
||||
name: ctx.Param("name"),
|
||||
command: command,
|
||||
args: ctx.QueryArray("arg"),
|
||||
wait: time.Duration(wait) * time.Second,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func acceptExecuteWebSocket(ctx *gin.Context) (*websocket.Conn, error) {
|
||||
return websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
})
|
||||
}
|
||||
|
||||
func (controller *Controller) establishExecuteSSHTunnel(
|
||||
ctx context.Context,
|
||||
waitCtx context.Context,
|
||||
vm *v1.VM,
|
||||
) (net.Conn, responder.Responder) {
|
||||
rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx)
|
||||
|
||||
session := uuid.New().String()
|
||||
connCh, cancelRequest := controller.connRendezvous.Request(rendezvousCtx, session)
|
||||
defer cancelRequest()
|
||||
|
||||
connCh, cancel := controller.connRendezvous.Request(rvCtx, session)
|
||||
defer cancel()
|
||||
|
||||
err = controller.workerNotifier.Notify(waitCtx, vm.Worker, &rpc.WatchInstruction{
|
||||
err := controller.workerNotifier.Notify(waitCtx, vm.Worker, &rpc.WatchInstruction{
|
||||
Action: &rpc.WatchInstruction_PortForwardAction{
|
||||
PortForwardAction: &rpc.WatchInstruction_PortForward{
|
||||
Session: session,
|
||||
|
|
@ -79,56 +133,49 @@ func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder {
|
|||
},
|
||||
})
|
||||
if err != nil {
|
||||
rendezvousCtxCancel()
|
||||
|
||||
controller.logger.Warnf("failed to request VM SSH port-forwarding from the worker %s: %v",
|
||||
vm.Worker, err)
|
||||
|
||||
return responder.Code(http.StatusServiceUnavailable)
|
||||
return nil, responder.Code(http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
timeoutTimer := time.NewTimer(executeSessionRendezvousTimeout)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
select {
|
||||
case rvResp := <-connCh:
|
||||
if rvResp.ErrorMessage != "" {
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
|
||||
"failed to establish SSH connection to the VM on the worker: %s", rvResp.ErrorMessage))
|
||||
case rendezvousResp := <-connCh:
|
||||
if rendezvousResp.ErrorMessage != "" {
|
||||
rendezvousCtxCancel()
|
||||
|
||||
return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
|
||||
"failed to establish SSH connection to the VM on the worker: %s", rendezvousResp.ErrorMessage))
|
||||
}
|
||||
|
||||
if rvResp.Result == nil {
|
||||
return responder.Code(http.StatusServiceUnavailable)
|
||||
if rendezvousResp.Result == nil {
|
||||
rendezvousCtxCancel()
|
||||
|
||||
return nil, responder.Code(http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
ws, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
})
|
||||
if err != nil {
|
||||
_ = rvResp.Result.Close()
|
||||
|
||||
return responder.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = ws.CloseNow()
|
||||
}()
|
||||
|
||||
tunnel := netconncancel.New(rvResp.Result, rvCancel)
|
||||
defer func() {
|
||||
_ = tunnel.Close()
|
||||
}()
|
||||
|
||||
return controller.executeVMViaSSHTunnel(ctx, tunnel, ws, vm, command, args)
|
||||
return netconncancel.New(rendezvousResp.Result, rendezvousCtxCancel), nil
|
||||
case <-timeoutTimer.C:
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
|
||||
rendezvousCtxCancel()
|
||||
|
||||
return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
|
||||
"timed out waiting for worker %s to establish SSH tunnel", vm.Worker))
|
||||
case <-ctx.Done():
|
||||
return responder.Error(ctx.Err())
|
||||
rendezvousCtxCancel()
|
||||
|
||||
return nil, responder.Error(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel net.Conn, ws *websocket.Conn, vm *v1.VM, cmd string, args []string) responder.Responder {
|
||||
sshClient, err := newSSHClient(tunnel, vm)
|
||||
if err != nil {
|
||||
controller.closeExecuteWithFrameError(ctx, ws, nil,
|
||||
controller.closeExecuteWithFrameError(ws, nil,
|
||||
fmt.Sprintf("SSH handshake with the VM failed: %v", err))
|
||||
|
||||
return responder.Empty()
|
||||
|
|
@ -139,7 +186,7 @@ func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel
|
|||
|
||||
execution, err := startSSHExecution(sshClient, cmd, args)
|
||||
if err != nil {
|
||||
controller.closeExecuteWithFrameError(ctx, ws, nil, err.Error())
|
||||
controller.closeExecuteWithFrameError(ws, nil, err.Error())
|
||||
|
||||
return responder.Empty()
|
||||
}
|
||||
|
|
@ -147,7 +194,7 @@ func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel
|
|||
_ = execution.session.Close()
|
||||
}()
|
||||
|
||||
return controller.pumpExecuteFrames(ctx, ws, execution)
|
||||
return controller.runExecuteSession(ctx, ws, execution)
|
||||
}
|
||||
|
||||
func newSSHClient(conn net.Conn, vm *v1.VM) (*ssh.Client, error) {
|
||||
|
|
@ -217,12 +264,61 @@ func startSSHExecution(client *ssh.Client, cmd string, args []string) (*sshExecu
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (controller *Controller) pumpExecuteFrames(
|
||||
func startExecuteSessionChannels(
|
||||
sessionCtx context.Context,
|
||||
decoder *json.Decoder,
|
||||
execution *sshExecution,
|
||||
) executeSessionChannels {
|
||||
channels := executeSessionChannels{
|
||||
outputFrameCh: make(chan execstream.Frame, 16),
|
||||
outputDoneCh: make(chan struct{}, 2),
|
||||
outputErrCh: make(chan error, 2),
|
||||
stdinErrCh: make(chan error, 1),
|
||||
exitCodeCh: make(chan int32, 1),
|
||||
exitErrCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
go forwardSSHOutputFrames(sessionCtx, execution.stdout, execstream.FrameTypeStdout,
|
||||
channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh)
|
||||
go forwardSSHOutputFrames(sessionCtx, execution.stderr, execstream.FrameTypeStderr,
|
||||
channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh)
|
||||
go consumeClientInputFrames(decoder, execution.stdin, channels.stdinErrCh)
|
||||
go waitForSSHExecutionExit(execution.session, channels.exitCodeCh, channels.exitErrCh)
|
||||
|
||||
return channels
|
||||
}
|
||||
|
||||
func flushExecuteOutputFrames(encoder *json.Encoder, outputFrameCh <-chan execstream.Frame) error {
|
||||
for len(outputFrameCh) > 0 {
|
||||
frame := <-outputFrameCh
|
||||
if err := execstream.WriteFrame(encoder, &frame); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func firstExecuteOutputError(outputErrCh <-chan error) error {
|
||||
for len(outputErrCh) > 0 {
|
||||
err := <-outputErrCh
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (controller *Controller) runExecuteSession(
|
||||
ctx context.Context,
|
||||
ws *websocket.Conn,
|
||||
execution *sshExecution,
|
||||
) responder.Responder {
|
||||
wsNetConn := websocket.NetConn(ctx, ws, websocket.MessageText)
|
||||
sessionCtx, sessionCancel := context.WithCancel(ctx)
|
||||
defer sessionCancel()
|
||||
|
||||
wsNetConn := websocket.NetConn(sessionCtx, ws, websocket.MessageText)
|
||||
defer func() {
|
||||
_ = wsNetConn.Close()
|
||||
}()
|
||||
|
|
@ -230,33 +326,27 @@ func (controller *Controller) pumpExecuteFrames(
|
|||
encoder := execstream.NewEncoder(wsNetConn)
|
||||
decoder := execstream.NewDecoder(wsNetConn)
|
||||
|
||||
outCh := make(chan execstream.Frame, 16)
|
||||
outDoneCh := make(chan struct{}, 2)
|
||||
outErrCh := make(chan error, 1)
|
||||
stdinErrCh := make(chan error, 1)
|
||||
exitCh := make(chan int32, 1)
|
||||
exitErrCh := make(chan error, 1)
|
||||
|
||||
go streamExecuteOutput(execution.stdout, execstream.FrameTypeStdout, outCh, outDoneCh, outErrCh)
|
||||
go streamExecuteOutput(execution.stderr, execstream.FrameTypeStderr, outCh, outDoneCh, outErrCh)
|
||||
go streamExecuteClientFrames(decoder, execution.stdin, stdinErrCh)
|
||||
go waitForSSHExecutionExit(execution.session, exitCh, exitErrCh)
|
||||
channels := startExecuteSessionChannels(sessionCtx, decoder, execution)
|
||||
|
||||
pingTicker := time.NewTicker(controller.pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
readersDone := 0
|
||||
exitObserved := false
|
||||
outputReadersDone := 0
|
||||
exitSeen := false
|
||||
var exitCode int32
|
||||
|
||||
for {
|
||||
if exitObserved && readersDone >= 2 {
|
||||
for len(outCh) > 0 {
|
||||
frame := <-outCh
|
||||
if err := execstream.WriteFrame(encoder, &frame); err != nil {
|
||||
return controller.wsError(ws, websocket.StatusInternalError, "execute session",
|
||||
"failed to stream execute output to the client", err)
|
||||
}
|
||||
if exitSeen && outputReadersDone >= 2 {
|
||||
if outputErr := firstExecuteOutputError(channels.outputErrCh); outputErr != nil {
|
||||
controller.closeExecuteWithFrameError(ws, encoder,
|
||||
fmt.Sprintf("failed while streaming command output: %v", outputErr))
|
||||
|
||||
return responder.Empty()
|
||||
}
|
||||
|
||||
if err := flushExecuteOutputFrames(encoder, channels.outputFrameCh); err != nil {
|
||||
return controller.wsError(ws, websocket.StatusInternalError, "execute session",
|
||||
"failed to stream execute output to the client", err)
|
||||
}
|
||||
|
||||
if err := execstream.WriteFrame(encoder, &execstream.Frame{
|
||||
|
|
@ -276,25 +366,25 @@ func (controller *Controller) pumpExecuteFrames(
|
|||
}
|
||||
|
||||
select {
|
||||
case frame := <-outCh:
|
||||
case frame := <-channels.outputFrameCh:
|
||||
if err := execstream.WriteFrame(encoder, &frame); err != nil {
|
||||
return controller.wsError(ws, websocket.StatusInternalError, "execute session",
|
||||
"failed to stream execute output to the client", err)
|
||||
}
|
||||
case <-outDoneCh:
|
||||
readersDone++
|
||||
case err := <-outErrCh:
|
||||
case <-channels.outputDoneCh:
|
||||
outputReadersDone++
|
||||
case err := <-channels.outputErrCh:
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
controller.closeExecuteWithFrameError(ctx, ws, encoder,
|
||||
controller.closeExecuteWithFrameError(ws, encoder,
|
||||
fmt.Sprintf("failed while streaming command output: %v", err))
|
||||
|
||||
return responder.Empty()
|
||||
case err := <-stdinErrCh:
|
||||
case err := <-channels.stdinErrCh:
|
||||
if err == nil || errors.Is(err, context.Canceled) {
|
||||
stdinErrCh = nil
|
||||
channels.stdinErrCh = nil
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -302,15 +392,15 @@ func (controller *Controller) pumpExecuteFrames(
|
|||
return responder.Empty()
|
||||
}
|
||||
|
||||
controller.closeExecuteWithFrameError(ctx, ws, encoder,
|
||||
controller.closeExecuteWithFrameError(ws, encoder,
|
||||
fmt.Sprintf("failed while reading command stdin stream: %v", err))
|
||||
|
||||
return responder.Empty()
|
||||
case code := <-exitCh:
|
||||
exitObserved = true
|
||||
case code := <-channels.exitCodeCh:
|
||||
exitSeen = true
|
||||
exitCode = code
|
||||
case err := <-exitErrCh:
|
||||
controller.closeExecuteWithFrameError(ctx, ws, encoder,
|
||||
case err := <-channels.exitErrCh:
|
||||
controller.closeExecuteWithFrameError(ws, encoder,
|
||||
fmt.Sprintf("failed while waiting for command completion: %v", err))
|
||||
|
||||
return responder.Empty()
|
||||
|
|
@ -346,7 +436,7 @@ func waitForSSHExecutionExit(sshSession *ssh.Session, exitCodeCh chan<- int32, e
|
|||
exitCodeCh <- 0
|
||||
}
|
||||
|
||||
func streamExecuteClientFrames(
|
||||
func consumeClientInputFrames(
|
||||
decoder *json.Decoder,
|
||||
stdin io.WriteCloser,
|
||||
errCh chan<- error,
|
||||
|
|
@ -409,25 +499,33 @@ func streamExecuteClientFrames(
|
|||
}
|
||||
}
|
||||
|
||||
func streamExecuteOutput(
|
||||
func forwardSSHOutputFrames(
|
||||
ctx context.Context,
|
||||
reader io.Reader,
|
||||
frameType execstream.FrameType,
|
||||
outputCh chan<- execstream.Frame,
|
||||
doneCh chan<- struct{},
|
||||
errCh chan<- error,
|
||||
outputFrameCh chan<- execstream.Frame,
|
||||
outputDoneCh chan<- struct{},
|
||||
outputErrCh chan<- error,
|
||||
) {
|
||||
defer func() {
|
||||
doneCh <- struct{}{}
|
||||
outputDoneCh <- struct{}{}
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
buffer := make([]byte, 4096)
|
||||
n, err := reader.Read(buffer)
|
||||
if n > 0 {
|
||||
outputCh <- execstream.Frame{
|
||||
frame := execstream.Frame{
|
||||
Type: frameType,
|
||||
Data: append([]byte(nil), buffer[:n]...),
|
||||
}
|
||||
|
||||
select {
|
||||
case outputFrameCh <- frame:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
|
|
@ -436,7 +534,7 @@ func streamExecuteOutput(
|
|||
|
||||
if err != nil {
|
||||
select {
|
||||
case errCh <- err:
|
||||
case outputErrCh <- err:
|
||||
default:
|
||||
}
|
||||
|
||||
|
|
@ -464,7 +562,6 @@ func shellQuoteArg(arg string) string {
|
|||
}
|
||||
|
||||
func (controller *Controller) closeExecuteWithFrameError(
|
||||
ctx context.Context,
|
||||
wsConn *websocket.Conn,
|
||||
encoder *json.Encoder,
|
||||
message string,
|
||||
|
|
|
|||
|
|
@ -3,8 +3,11 @@ package controller
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -21,7 +24,7 @@ func (writer *recordingWriteCloser) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func TestStreamExecuteClientFramesWritesInputAndClosesOnEOFFrame(t *testing.T) {
|
||||
func TestConsumeClientInputFramesWritesInputAndClosesOnEOFFrame(t *testing.T) {
|
||||
var input bytes.Buffer
|
||||
encoder := execstream.NewEncoder(&input)
|
||||
|
||||
|
|
@ -41,14 +44,14 @@ func TestStreamExecuteClientFramesWritesInputAndClosesOnEOFFrame(t *testing.T) {
|
|||
stdin := &recordingWriteCloser{}
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
streamExecuteClientFrames(decoder, stdin, errCh)
|
||||
consumeClientInputFrames(decoder, stdin, errCh)
|
||||
|
||||
require.NoError(t, <-errCh)
|
||||
require.True(t, stdin.closed)
|
||||
require.Equal(t, "hello", stdin.String())
|
||||
}
|
||||
|
||||
func TestStreamExecuteClientFramesUnsupportedType(t *testing.T) {
|
||||
func TestConsumeClientInputFramesUnsupportedType(t *testing.T) {
|
||||
var input bytes.Buffer
|
||||
encoder := execstream.NewEncoder(&input)
|
||||
|
||||
|
|
@ -61,33 +64,33 @@ func TestStreamExecuteClientFramesUnsupportedType(t *testing.T) {
|
|||
stdin := &recordingWriteCloser{}
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
streamExecuteClientFrames(decoder, stdin, errCh)
|
||||
consumeClientInputFrames(decoder, stdin, errCh)
|
||||
|
||||
require.EqualError(t, <-errCh, "unsupported frame type \"stdout\" received from client")
|
||||
require.False(t, stdin.closed)
|
||||
}
|
||||
|
||||
func TestStreamExecuteClientFramesClosesStdinOnDecodeError(t *testing.T) {
|
||||
func TestConsumeClientInputFramesClosesStdinOnDecodeError(t *testing.T) {
|
||||
decoder := execstream.NewDecoder(bytes.NewBuffer(nil))
|
||||
stdin := &recordingWriteCloser{}
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
streamExecuteClientFrames(decoder, stdin, errCh)
|
||||
consumeClientInputFrames(decoder, stdin, errCh)
|
||||
|
||||
require.ErrorIs(t, <-errCh, io.EOF)
|
||||
require.True(t, stdin.closed)
|
||||
}
|
||||
|
||||
func TestStreamExecuteOutputEmitsFrameAndSignalsDone(t *testing.T) {
|
||||
outputCh := make(chan execstream.Frame, 1)
|
||||
doneCh := make(chan struct{}, 1)
|
||||
errCh := make(chan error, 1)
|
||||
func TestForwardSSHOutputFramesEmitsFrameAndSignalsDone(t *testing.T) {
|
||||
outputFrameCh := make(chan execstream.Frame, 1)
|
||||
outputDoneCh := make(chan struct{}, 1)
|
||||
outputErrCh := make(chan error, 1)
|
||||
|
||||
streamExecuteOutput(bytes.NewBufferString("payload"),
|
||||
execstream.FrameTypeStderr, outputCh, doneCh, errCh)
|
||||
forwardSSHOutputFrames(context.Background(), bytes.NewBufferString("payload"),
|
||||
execstream.FrameTypeStderr, outputFrameCh, outputDoneCh, outputErrCh)
|
||||
|
||||
select {
|
||||
case frame := <-outputCh:
|
||||
case frame := <-outputFrameCh:
|
||||
require.Equal(t, execstream.FrameTypeStderr, frame.Type)
|
||||
require.Equal(t, []byte("payload"), frame.Data)
|
||||
default:
|
||||
|
|
@ -95,18 +98,68 @@ func TestStreamExecuteOutputEmitsFrameAndSignalsDone(t *testing.T) {
|
|||
}
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-outputDoneCh:
|
||||
default:
|
||||
t.Fatal("expected done signal")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
case err := <-outputErrCh:
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardSSHOutputFramesStopsWhenContextCancelledWhileOutputChannelIsBlocked(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
outputFrameCh := make(chan execstream.Frame, 1)
|
||||
outputFrameCh <- execstream.Frame{
|
||||
Type: execstream.FrameTypeStdout,
|
||||
Data: []byte("occupied"),
|
||||
}
|
||||
|
||||
outputDoneCh := make(chan struct{}, 1)
|
||||
outputErrCh := make(chan error, 1)
|
||||
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
forwardSSHOutputFrames(ctx, bytes.NewBufferString("payload"),
|
||||
execstream.FrameTypeStdout, outputFrameCh, outputDoneCh, outputErrCh)
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-finished:
|
||||
t.Fatal("forwardSSHOutputFrames unexpectedly returned before context cancellation")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("forwardSSHOutputFrames did not return after context cancellation")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-outputDoneCh:
|
||||
default:
|
||||
t.Fatal("expected done signal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstExecuteOutputErrorReturnsFirstNonNilError(t *testing.T) {
|
||||
outputErrCh := make(chan error, 3)
|
||||
outputErrCh <- nil
|
||||
outputErrCh <- errors.New("first error")
|
||||
outputErrCh <- errors.New("second error")
|
||||
|
||||
require.EqualError(t, firstExecuteOutputError(outputErrCh), "first error")
|
||||
}
|
||||
|
||||
func TestBuildSSHCommandQuotesArguments(t *testing.T) {
|
||||
result := buildSSHCommand("echo", []string{"hello world", "a'b", ""})
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue