From 9c55014cc880500d0d8a462fdc6c0389dc114c41 Mon Sep 17 00:00:00 2001 From: Fedor Korotkov Date: Sun, 8 Feb 2026 20:41:05 +0100 Subject: [PATCH] Refactor executeVM logic and introduce helper functions for request parsing, WebSocket handling, and session execution. Add related tests. --- internal/controller/api_vms_execute.go | 295 +++++++++++++------- internal/controller/api_vms_execute_test.go | 83 +++++- 2 files changed, 264 insertions(+), 114 deletions(-) diff --git a/internal/controller/api_vms_execute.go b/internal/controller/api_vms_execute.go index c915b3e..83a6240 100644 --- a/internal/controller/api_vms_execute.go +++ b/internal/controller/api_vms_execute.go @@ -32,44 +32,98 @@ type sshExecution struct { stdin io.WriteCloser } +type executeVMRequest struct { + name string + command string + args []string + wait time.Duration +} + +type executeSessionChannels struct { + outputFrameCh chan execstream.Frame + outputDoneCh chan struct{} + outputErrCh chan error + stdinErrCh chan error + exitCodeCh chan int32 + exitErrCh chan error +} + func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder { if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite, v1.ServiceAccountRoleComputeConnect); responder != nil { return responder } - name := ctx.Param("name") - - command := ctx.Query("command") - if command == "" { - return responder.Code(http.StatusBadRequest) - } - - args := ctx.QueryArray("arg") - - waitRaw := ctx.DefaultQuery("wait", "10") - wait, err := strconv.ParseUint(waitRaw, 10, 16) - if err != nil { - return responder.Code(http.StatusBadRequest) - } - - waitCtx, waitCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second) - defer waitCancel() - - vm, responderImpl := controller.waitForVM(waitCtx, name) + request, responderImpl := parseExecuteVMRequest(ctx) if responderImpl != nil { return responderImpl } - rvCtx, rvCancel := context.WithCancel(ctx) - defer rvCancel() + waitCtx, waitCancel := context.WithTimeout(ctx, request.wait) + defer waitCancel() + + vm, responderImpl := controller.waitForVM(waitCtx, request.name) + if responderImpl != nil { + return responderImpl + } + + tunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, waitCtx, vm) + if responderImpl != nil { + return responderImpl + } + defer func() { + _ = tunnel.Close() + }() + + wsConn, err := acceptExecuteWebSocket(ctx) + if err != nil { + return responder.Error(err) + } + defer func() { + _ = wsConn.CloseNow() + }() + + return controller.executeVMViaSSHTunnel(ctx, tunnel, wsConn, vm, request.command, request.args) +} + +func parseExecuteVMRequest(ctx *gin.Context) (*executeVMRequest, responder.Responder) { + command := ctx.Query("command") + if command == "" { + return nil, responder.Code(http.StatusBadRequest) + } + + waitRaw := ctx.DefaultQuery("wait", "10") + wait, err := strconv.ParseUint(waitRaw, 10, 16) + if err != nil { + return nil, responder.Code(http.StatusBadRequest) + } + + return &executeVMRequest{ + name: ctx.Param("name"), + command: command, + args: ctx.QueryArray("arg"), + wait: time.Duration(wait) * time.Second, + }, nil +} + +func acceptExecuteWebSocket(ctx *gin.Context) (*websocket.Conn, error) { + return websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) +} + +func (controller *Controller) establishExecuteSSHTunnel( + ctx context.Context, + waitCtx context.Context, + vm *v1.VM, +) (net.Conn, responder.Responder) { + rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx) session := uuid.New().String() + connCh, cancelRequest := controller.connRendezvous.Request(rendezvousCtx, session) + defer cancelRequest() - connCh, cancel := controller.connRendezvous.Request(rvCtx, session) - defer cancel() - - err = controller.workerNotifier.Notify(waitCtx, vm.Worker, &rpc.WatchInstruction{ + err := controller.workerNotifier.Notify(waitCtx, vm.Worker, &rpc.WatchInstruction{ Action: &rpc.WatchInstruction_PortForwardAction{ PortForwardAction: &rpc.WatchInstruction_PortForward{ Session: session, @@ -79,56 +133,49 @@ func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder { }, }) if err != nil { + rendezvousCtxCancel() + controller.logger.Warnf("failed to request VM SSH port-forwarding from the worker %s: %v", vm.Worker, err) - return responder.Code(http.StatusServiceUnavailable) + return nil, responder.Code(http.StatusServiceUnavailable) } timeoutTimer := time.NewTimer(executeSessionRendezvousTimeout) defer timeoutTimer.Stop() select { - case rvResp := <-connCh: - if rvResp.ErrorMessage != "" { - return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( - "failed to establish SSH connection to the VM on the worker: %s", rvResp.ErrorMessage)) + case rendezvousResp := <-connCh: + if rendezvousResp.ErrorMessage != "" { + rendezvousCtxCancel() + + return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( + "failed to establish SSH connection to the VM on the worker: %s", rendezvousResp.ErrorMessage)) } - if rvResp.Result == nil { - return responder.Code(http.StatusServiceUnavailable) + if rendezvousResp.Result == nil { + rendezvousCtxCancel() + + return nil, responder.Code(http.StatusServiceUnavailable) } - ws, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ - OriginPatterns: []string{"*"}, - }) - if err != nil { - _ = rvResp.Result.Close() - - return responder.Error(err) - } - defer func() { - _ = ws.CloseNow() - }() - - tunnel := netconncancel.New(rvResp.Result, rvCancel) - defer func() { - _ = tunnel.Close() - }() - - return controller.executeVMViaSSHTunnel(ctx, tunnel, ws, vm, command, args) + return netconncancel.New(rendezvousResp.Result, rendezvousCtxCancel), nil case <-timeoutTimer.C: - return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( + rendezvousCtxCancel() + + return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( "timed out waiting for worker %s to establish SSH tunnel", vm.Worker)) case <-ctx.Done(): - return responder.Error(ctx.Err()) + rendezvousCtxCancel() + + return nil, responder.Error(ctx.Err()) } } func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel net.Conn, ws *websocket.Conn, vm *v1.VM, cmd string, args []string) responder.Responder { sshClient, err := newSSHClient(tunnel, vm) if err != nil { - controller.closeExecuteWithFrameError(ctx, ws, nil, + controller.closeExecuteWithFrameError(ws, nil, fmt.Sprintf("SSH handshake with the VM failed: %v", err)) return responder.Empty() @@ -139,7 +186,7 @@ func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel execution, err := startSSHExecution(sshClient, cmd, args) if err != nil { - controller.closeExecuteWithFrameError(ctx, ws, nil, err.Error()) + controller.closeExecuteWithFrameError(ws, nil, err.Error()) return responder.Empty() } @@ -147,7 +194,7 @@ func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel _ = execution.session.Close() }() - return controller.pumpExecuteFrames(ctx, ws, execution) + return controller.runExecuteSession(ctx, ws, execution) } func newSSHClient(conn net.Conn, vm *v1.VM) (*ssh.Client, error) { @@ -217,12 +264,61 @@ func startSSHExecution(client *ssh.Client, cmd string, args []string) (*sshExecu }, nil } -func (controller *Controller) pumpExecuteFrames( +func startExecuteSessionChannels( + sessionCtx context.Context, + decoder *json.Decoder, + execution *sshExecution, +) executeSessionChannels { + channels := executeSessionChannels{ + outputFrameCh: make(chan execstream.Frame, 16), + outputDoneCh: make(chan struct{}, 2), + outputErrCh: make(chan error, 2), + stdinErrCh: make(chan error, 1), + exitCodeCh: make(chan int32, 1), + exitErrCh: make(chan error, 1), + } + + go forwardSSHOutputFrames(sessionCtx, execution.stdout, execstream.FrameTypeStdout, + channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh) + go forwardSSHOutputFrames(sessionCtx, execution.stderr, execstream.FrameTypeStderr, + channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh) + go consumeClientInputFrames(decoder, execution.stdin, channels.stdinErrCh) + go waitForSSHExecutionExit(execution.session, channels.exitCodeCh, channels.exitErrCh) + + return channels +} + +func flushExecuteOutputFrames(encoder *json.Encoder, outputFrameCh <-chan execstream.Frame) error { + for len(outputFrameCh) > 0 { + frame := <-outputFrameCh + if err := execstream.WriteFrame(encoder, &frame); err != nil { + return err + } + } + + return nil +} + +func firstExecuteOutputError(outputErrCh <-chan error) error { + for len(outputErrCh) > 0 { + err := <-outputErrCh + if err != nil { + return err + } + } + + return nil +} + +func (controller *Controller) runExecuteSession( ctx context.Context, ws *websocket.Conn, execution *sshExecution, ) responder.Responder { - wsNetConn := websocket.NetConn(ctx, ws, websocket.MessageText) + sessionCtx, sessionCancel := context.WithCancel(ctx) + defer sessionCancel() + + wsNetConn := websocket.NetConn(sessionCtx, ws, websocket.MessageText) defer func() { _ = wsNetConn.Close() }() @@ -230,33 +326,27 @@ func (controller *Controller) pumpExecuteFrames( encoder := execstream.NewEncoder(wsNetConn) decoder := execstream.NewDecoder(wsNetConn) - outCh := make(chan execstream.Frame, 16) - outDoneCh := make(chan struct{}, 2) - outErrCh := make(chan error, 1) - stdinErrCh := make(chan error, 1) - exitCh := make(chan int32, 1) - exitErrCh := make(chan error, 1) - - go streamExecuteOutput(execution.stdout, execstream.FrameTypeStdout, outCh, outDoneCh, outErrCh) - go streamExecuteOutput(execution.stderr, execstream.FrameTypeStderr, outCh, outDoneCh, outErrCh) - go streamExecuteClientFrames(decoder, execution.stdin, stdinErrCh) - go waitForSSHExecutionExit(execution.session, exitCh, exitErrCh) + channels := startExecuteSessionChannels(sessionCtx, decoder, execution) pingTicker := time.NewTicker(controller.pingInterval) defer pingTicker.Stop() - readersDone := 0 - exitObserved := false + outputReadersDone := 0 + exitSeen := false var exitCode int32 for { - if exitObserved && readersDone >= 2 { - for len(outCh) > 0 { - frame := <-outCh - if err := execstream.WriteFrame(encoder, &frame); err != nil { - return controller.wsError(ws, websocket.StatusInternalError, "execute session", - "failed to stream execute output to the client", err) - } + if exitSeen && outputReadersDone >= 2 { + if outputErr := firstExecuteOutputError(channels.outputErrCh); outputErr != nil { + controller.closeExecuteWithFrameError(ws, encoder, + fmt.Sprintf("failed while streaming command output: %v", outputErr)) + + return responder.Empty() + } + + if err := flushExecuteOutputFrames(encoder, channels.outputFrameCh); err != nil { + return controller.wsError(ws, websocket.StatusInternalError, "execute session", + "failed to stream execute output to the client", err) } if err := execstream.WriteFrame(encoder, &execstream.Frame{ @@ -276,25 +366,25 @@ func (controller *Controller) pumpExecuteFrames( } select { - case frame := <-outCh: + case frame := <-channels.outputFrameCh: if err := execstream.WriteFrame(encoder, &frame); err != nil { return controller.wsError(ws, websocket.StatusInternalError, "execute session", "failed to stream execute output to the client", err) } - case <-outDoneCh: - readersDone++ - case err := <-outErrCh: + case <-channels.outputDoneCh: + outputReadersDone++ + case err := <-channels.outputErrCh: if err == nil { continue } - controller.closeExecuteWithFrameError(ctx, ws, encoder, + controller.closeExecuteWithFrameError(ws, encoder, fmt.Sprintf("failed while streaming command output: %v", err)) return responder.Empty() - case err := <-stdinErrCh: + case err := <-channels.stdinErrCh: if err == nil || errors.Is(err, context.Canceled) { - stdinErrCh = nil + channels.stdinErrCh = nil continue } @@ -302,15 +392,15 @@ func (controller *Controller) pumpExecuteFrames( return responder.Empty() } - controller.closeExecuteWithFrameError(ctx, ws, encoder, + controller.closeExecuteWithFrameError(ws, encoder, fmt.Sprintf("failed while reading command stdin stream: %v", err)) return responder.Empty() - case code := <-exitCh: - exitObserved = true + case code := <-channels.exitCodeCh: + exitSeen = true exitCode = code - case err := <-exitErrCh: - controller.closeExecuteWithFrameError(ctx, ws, encoder, + case err := <-channels.exitErrCh: + controller.closeExecuteWithFrameError(ws, encoder, fmt.Sprintf("failed while waiting for command completion: %v", err)) return responder.Empty() @@ -346,7 +436,7 @@ func waitForSSHExecutionExit(sshSession *ssh.Session, exitCodeCh chan<- int32, e exitCodeCh <- 0 } -func streamExecuteClientFrames( +func consumeClientInputFrames( decoder *json.Decoder, stdin io.WriteCloser, errCh chan<- error, @@ -409,25 +499,33 @@ func streamExecuteClientFrames( } } -func streamExecuteOutput( +func forwardSSHOutputFrames( + ctx context.Context, reader io.Reader, frameType execstream.FrameType, - outputCh chan<- execstream.Frame, - doneCh chan<- struct{}, - errCh chan<- error, + outputFrameCh chan<- execstream.Frame, + outputDoneCh chan<- struct{}, + outputErrCh chan<- error, ) { defer func() { - doneCh <- struct{}{} + outputDoneCh <- struct{}{} }() + buffer := make([]byte, 4096) + for { - buffer := make([]byte, 4096) n, err := reader.Read(buffer) if n > 0 { - outputCh <- execstream.Frame{ + frame := execstream.Frame{ Type: frameType, Data: append([]byte(nil), buffer[:n]...), } + + select { + case outputFrameCh <- frame: + case <-ctx.Done(): + return + } } if errors.Is(err, io.EOF) { @@ -436,7 +534,7 @@ func streamExecuteOutput( if err != nil { select { - case errCh <- err: + case outputErrCh <- err: default: } @@ -464,7 +562,6 @@ func shellQuoteArg(arg string) string { } func (controller *Controller) closeExecuteWithFrameError( - ctx context.Context, wsConn *websocket.Conn, encoder *json.Encoder, message string, diff --git a/internal/controller/api_vms_execute_test.go b/internal/controller/api_vms_execute_test.go index 96acd0b..caa9692 100644 --- a/internal/controller/api_vms_execute_test.go +++ b/internal/controller/api_vms_execute_test.go @@ -3,8 +3,11 @@ package controller import ( "bytes" + "context" + "errors" "io" "testing" + "time" "github.com/cirruslabs/orchard/internal/execstream" "github.com/stretchr/testify/require" @@ -21,7 +24,7 @@ func (writer *recordingWriteCloser) Close() error { return nil } -func TestStreamExecuteClientFramesWritesInputAndClosesOnEOFFrame(t *testing.T) { +func TestConsumeClientInputFramesWritesInputAndClosesOnEOFFrame(t *testing.T) { var input bytes.Buffer encoder := execstream.NewEncoder(&input) @@ -41,14 +44,14 @@ func TestStreamExecuteClientFramesWritesInputAndClosesOnEOFFrame(t *testing.T) { stdin := &recordingWriteCloser{} errCh := make(chan error, 1) - streamExecuteClientFrames(decoder, stdin, errCh) + consumeClientInputFrames(decoder, stdin, errCh) require.NoError(t, <-errCh) require.True(t, stdin.closed) require.Equal(t, "hello", stdin.String()) } -func TestStreamExecuteClientFramesUnsupportedType(t *testing.T) { +func TestConsumeClientInputFramesUnsupportedType(t *testing.T) { var input bytes.Buffer encoder := execstream.NewEncoder(&input) @@ -61,33 +64,33 @@ func TestStreamExecuteClientFramesUnsupportedType(t *testing.T) { stdin := &recordingWriteCloser{} errCh := make(chan error, 1) - streamExecuteClientFrames(decoder, stdin, errCh) + consumeClientInputFrames(decoder, stdin, errCh) require.EqualError(t, <-errCh, "unsupported frame type \"stdout\" received from client") require.False(t, stdin.closed) } -func TestStreamExecuteClientFramesClosesStdinOnDecodeError(t *testing.T) { +func TestConsumeClientInputFramesClosesStdinOnDecodeError(t *testing.T) { decoder := execstream.NewDecoder(bytes.NewBuffer(nil)) stdin := &recordingWriteCloser{} errCh := make(chan error, 1) - streamExecuteClientFrames(decoder, stdin, errCh) + consumeClientInputFrames(decoder, stdin, errCh) require.ErrorIs(t, <-errCh, io.EOF) require.True(t, stdin.closed) } -func TestStreamExecuteOutputEmitsFrameAndSignalsDone(t *testing.T) { - outputCh := make(chan execstream.Frame, 1) - doneCh := make(chan struct{}, 1) - errCh := make(chan error, 1) +func TestForwardSSHOutputFramesEmitsFrameAndSignalsDone(t *testing.T) { + outputFrameCh := make(chan execstream.Frame, 1) + outputDoneCh := make(chan struct{}, 1) + outputErrCh := make(chan error, 1) - streamExecuteOutput(bytes.NewBufferString("payload"), - execstream.FrameTypeStderr, outputCh, doneCh, errCh) + forwardSSHOutputFrames(context.Background(), bytes.NewBufferString("payload"), + execstream.FrameTypeStderr, outputFrameCh, outputDoneCh, outputErrCh) select { - case frame := <-outputCh: + case frame := <-outputFrameCh: require.Equal(t, execstream.FrameTypeStderr, frame.Type) require.Equal(t, []byte("payload"), frame.Data) default: @@ -95,18 +98,68 @@ func TestStreamExecuteOutputEmitsFrameAndSignalsDone(t *testing.T) { } select { - case <-doneCh: + case <-outputDoneCh: default: t.Fatal("expected done signal") } select { - case err := <-errCh: + case err := <-outputErrCh: t.Fatalf("unexpected error: %v", err) default: } } +func TestForwardSSHOutputFramesStopsWhenContextCancelledWhileOutputChannelIsBlocked(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + outputFrameCh := make(chan execstream.Frame, 1) + outputFrameCh <- execstream.Frame{ + Type: execstream.FrameTypeStdout, + Data: []byte("occupied"), + } + + outputDoneCh := make(chan struct{}, 1) + outputErrCh := make(chan error, 1) + + finished := make(chan struct{}) + go func() { + forwardSSHOutputFrames(ctx, bytes.NewBufferString("payload"), + execstream.FrameTypeStdout, outputFrameCh, outputDoneCh, outputErrCh) + close(finished) + }() + + select { + case <-finished: + t.Fatal("forwardSSHOutputFrames unexpectedly returned before context cancellation") + case <-time.After(50 * time.Millisecond): + } + + cancel() + + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("forwardSSHOutputFrames did not return after context cancellation") + } + + select { + case <-outputDoneCh: + default: + t.Fatal("expected done signal") + } +} + +func TestFirstExecuteOutputErrorReturnsFirstNonNilError(t *testing.T) { + outputErrCh := make(chan error, 3) + outputErrCh <- nil + outputErrCh <- errors.New("first error") + outputErrCh <- errors.New("second error") + + require.EqualError(t, firstExecuteOutputError(outputErrCh), "first error") +} + func TestBuildSSHCommandQuotesArguments(t *testing.T) { result := buildSSHCommand("echo", []string{"hello world", "a'b", ""})