Refactor executeVM logic and introduce helper functions for request parsing, WebSocket handling, and session execution. Add related tests.

This commit is contained in:
Fedor Korotkov 2026-02-08 20:41:05 +01:00
parent dc0f5b45d0
commit 9c55014cc8
2 changed files with 264 additions and 114 deletions

View File

@ -32,44 +32,98 @@ type sshExecution struct {
stdin io.WriteCloser
}
type executeVMRequest struct {
name string
command string
args []string
wait time.Duration
}
type executeSessionChannels struct {
outputFrameCh chan execstream.Frame
outputDoneCh chan struct{}
outputErrCh chan error
stdinErrCh chan error
exitCodeCh chan int32
exitErrCh chan error
}
func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder {
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
v1.ServiceAccountRoleComputeConnect); responder != nil {
return responder
}
name := ctx.Param("name")
command := ctx.Query("command")
if command == "" {
return responder.Code(http.StatusBadRequest)
}
args := ctx.QueryArray("arg")
waitRaw := ctx.DefaultQuery("wait", "10")
wait, err := strconv.ParseUint(waitRaw, 10, 16)
if err != nil {
return responder.Code(http.StatusBadRequest)
}
waitCtx, waitCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
defer waitCancel()
vm, responderImpl := controller.waitForVM(waitCtx, name)
request, responderImpl := parseExecuteVMRequest(ctx)
if responderImpl != nil {
return responderImpl
}
rvCtx, rvCancel := context.WithCancel(ctx)
defer rvCancel()
waitCtx, waitCancel := context.WithTimeout(ctx, request.wait)
defer waitCancel()
vm, responderImpl := controller.waitForVM(waitCtx, request.name)
if responderImpl != nil {
return responderImpl
}
tunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, waitCtx, vm)
if responderImpl != nil {
return responderImpl
}
defer func() {
_ = tunnel.Close()
}()
wsConn, err := acceptExecuteWebSocket(ctx)
if err != nil {
return responder.Error(err)
}
defer func() {
_ = wsConn.CloseNow()
}()
return controller.executeVMViaSSHTunnel(ctx, tunnel, wsConn, vm, request.command, request.args)
}
func parseExecuteVMRequest(ctx *gin.Context) (*executeVMRequest, responder.Responder) {
command := ctx.Query("command")
if command == "" {
return nil, responder.Code(http.StatusBadRequest)
}
waitRaw := ctx.DefaultQuery("wait", "10")
wait, err := strconv.ParseUint(waitRaw, 10, 16)
if err != nil {
return nil, responder.Code(http.StatusBadRequest)
}
return &executeVMRequest{
name: ctx.Param("name"),
command: command,
args: ctx.QueryArray("arg"),
wait: time.Duration(wait) * time.Second,
}, nil
}
func acceptExecuteWebSocket(ctx *gin.Context) (*websocket.Conn, error) {
return websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
})
}
func (controller *Controller) establishExecuteSSHTunnel(
ctx context.Context,
waitCtx context.Context,
vm *v1.VM,
) (net.Conn, responder.Responder) {
rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx)
session := uuid.New().String()
connCh, cancelRequest := controller.connRendezvous.Request(rendezvousCtx, session)
defer cancelRequest()
connCh, cancel := controller.connRendezvous.Request(rvCtx, session)
defer cancel()
err = controller.workerNotifier.Notify(waitCtx, vm.Worker, &rpc.WatchInstruction{
err := controller.workerNotifier.Notify(waitCtx, vm.Worker, &rpc.WatchInstruction{
Action: &rpc.WatchInstruction_PortForwardAction{
PortForwardAction: &rpc.WatchInstruction_PortForward{
Session: session,
@ -79,56 +133,49 @@ func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder {
},
})
if err != nil {
rendezvousCtxCancel()
controller.logger.Warnf("failed to request VM SSH port-forwarding from the worker %s: %v",
vm.Worker, err)
return responder.Code(http.StatusServiceUnavailable)
return nil, responder.Code(http.StatusServiceUnavailable)
}
timeoutTimer := time.NewTimer(executeSessionRendezvousTimeout)
defer timeoutTimer.Stop()
select {
case rvResp := <-connCh:
if rvResp.ErrorMessage != "" {
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
"failed to establish SSH connection to the VM on the worker: %s", rvResp.ErrorMessage))
case rendezvousResp := <-connCh:
if rendezvousResp.ErrorMessage != "" {
rendezvousCtxCancel()
return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
"failed to establish SSH connection to the VM on the worker: %s", rendezvousResp.ErrorMessage))
}
if rvResp.Result == nil {
return responder.Code(http.StatusServiceUnavailable)
if rendezvousResp.Result == nil {
rendezvousCtxCancel()
return nil, responder.Code(http.StatusServiceUnavailable)
}
ws, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
})
if err != nil {
_ = rvResp.Result.Close()
return responder.Error(err)
}
defer func() {
_ = ws.CloseNow()
}()
tunnel := netconncancel.New(rvResp.Result, rvCancel)
defer func() {
_ = tunnel.Close()
}()
return controller.executeVMViaSSHTunnel(ctx, tunnel, ws, vm, command, args)
return netconncancel.New(rendezvousResp.Result, rendezvousCtxCancel), nil
case <-timeoutTimer.C:
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
rendezvousCtxCancel()
return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
"timed out waiting for worker %s to establish SSH tunnel", vm.Worker))
case <-ctx.Done():
return responder.Error(ctx.Err())
rendezvousCtxCancel()
return nil, responder.Error(ctx.Err())
}
}
func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel net.Conn, ws *websocket.Conn, vm *v1.VM, cmd string, args []string) responder.Responder {
sshClient, err := newSSHClient(tunnel, vm)
if err != nil {
controller.closeExecuteWithFrameError(ctx, ws, nil,
controller.closeExecuteWithFrameError(ws, nil,
fmt.Sprintf("SSH handshake with the VM failed: %v", err))
return responder.Empty()
@ -139,7 +186,7 @@ func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel
execution, err := startSSHExecution(sshClient, cmd, args)
if err != nil {
controller.closeExecuteWithFrameError(ctx, ws, nil, err.Error())
controller.closeExecuteWithFrameError(ws, nil, err.Error())
return responder.Empty()
}
@ -147,7 +194,7 @@ func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel
_ = execution.session.Close()
}()
return controller.pumpExecuteFrames(ctx, ws, execution)
return controller.runExecuteSession(ctx, ws, execution)
}
func newSSHClient(conn net.Conn, vm *v1.VM) (*ssh.Client, error) {
@ -217,12 +264,61 @@ func startSSHExecution(client *ssh.Client, cmd string, args []string) (*sshExecu
}, nil
}
func (controller *Controller) pumpExecuteFrames(
func startExecuteSessionChannels(
sessionCtx context.Context,
decoder *json.Decoder,
execution *sshExecution,
) executeSessionChannels {
channels := executeSessionChannels{
outputFrameCh: make(chan execstream.Frame, 16),
outputDoneCh: make(chan struct{}, 2),
outputErrCh: make(chan error, 2),
stdinErrCh: make(chan error, 1),
exitCodeCh: make(chan int32, 1),
exitErrCh: make(chan error, 1),
}
go forwardSSHOutputFrames(sessionCtx, execution.stdout, execstream.FrameTypeStdout,
channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh)
go forwardSSHOutputFrames(sessionCtx, execution.stderr, execstream.FrameTypeStderr,
channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh)
go consumeClientInputFrames(decoder, execution.stdin, channels.stdinErrCh)
go waitForSSHExecutionExit(execution.session, channels.exitCodeCh, channels.exitErrCh)
return channels
}
func flushExecuteOutputFrames(encoder *json.Encoder, outputFrameCh <-chan execstream.Frame) error {
for len(outputFrameCh) > 0 {
frame := <-outputFrameCh
if err := execstream.WriteFrame(encoder, &frame); err != nil {
return err
}
}
return nil
}
func firstExecuteOutputError(outputErrCh <-chan error) error {
for len(outputErrCh) > 0 {
err := <-outputErrCh
if err != nil {
return err
}
}
return nil
}
func (controller *Controller) runExecuteSession(
ctx context.Context,
ws *websocket.Conn,
execution *sshExecution,
) responder.Responder {
wsNetConn := websocket.NetConn(ctx, ws, websocket.MessageText)
sessionCtx, sessionCancel := context.WithCancel(ctx)
defer sessionCancel()
wsNetConn := websocket.NetConn(sessionCtx, ws, websocket.MessageText)
defer func() {
_ = wsNetConn.Close()
}()
@ -230,33 +326,27 @@ func (controller *Controller) pumpExecuteFrames(
encoder := execstream.NewEncoder(wsNetConn)
decoder := execstream.NewDecoder(wsNetConn)
outCh := make(chan execstream.Frame, 16)
outDoneCh := make(chan struct{}, 2)
outErrCh := make(chan error, 1)
stdinErrCh := make(chan error, 1)
exitCh := make(chan int32, 1)
exitErrCh := make(chan error, 1)
go streamExecuteOutput(execution.stdout, execstream.FrameTypeStdout, outCh, outDoneCh, outErrCh)
go streamExecuteOutput(execution.stderr, execstream.FrameTypeStderr, outCh, outDoneCh, outErrCh)
go streamExecuteClientFrames(decoder, execution.stdin, stdinErrCh)
go waitForSSHExecutionExit(execution.session, exitCh, exitErrCh)
channels := startExecuteSessionChannels(sessionCtx, decoder, execution)
pingTicker := time.NewTicker(controller.pingInterval)
defer pingTicker.Stop()
readersDone := 0
exitObserved := false
outputReadersDone := 0
exitSeen := false
var exitCode int32
for {
if exitObserved && readersDone >= 2 {
for len(outCh) > 0 {
frame := <-outCh
if err := execstream.WriteFrame(encoder, &frame); err != nil {
return controller.wsError(ws, websocket.StatusInternalError, "execute session",
"failed to stream execute output to the client", err)
}
if exitSeen && outputReadersDone >= 2 {
if outputErr := firstExecuteOutputError(channels.outputErrCh); outputErr != nil {
controller.closeExecuteWithFrameError(ws, encoder,
fmt.Sprintf("failed while streaming command output: %v", outputErr))
return responder.Empty()
}
if err := flushExecuteOutputFrames(encoder, channels.outputFrameCh); err != nil {
return controller.wsError(ws, websocket.StatusInternalError, "execute session",
"failed to stream execute output to the client", err)
}
if err := execstream.WriteFrame(encoder, &execstream.Frame{
@ -276,25 +366,25 @@ func (controller *Controller) pumpExecuteFrames(
}
select {
case frame := <-outCh:
case frame := <-channels.outputFrameCh:
if err := execstream.WriteFrame(encoder, &frame); err != nil {
return controller.wsError(ws, websocket.StatusInternalError, "execute session",
"failed to stream execute output to the client", err)
}
case <-outDoneCh:
readersDone++
case err := <-outErrCh:
case <-channels.outputDoneCh:
outputReadersDone++
case err := <-channels.outputErrCh:
if err == nil {
continue
}
controller.closeExecuteWithFrameError(ctx, ws, encoder,
controller.closeExecuteWithFrameError(ws, encoder,
fmt.Sprintf("failed while streaming command output: %v", err))
return responder.Empty()
case err := <-stdinErrCh:
case err := <-channels.stdinErrCh:
if err == nil || errors.Is(err, context.Canceled) {
stdinErrCh = nil
channels.stdinErrCh = nil
continue
}
@ -302,15 +392,15 @@ func (controller *Controller) pumpExecuteFrames(
return responder.Empty()
}
controller.closeExecuteWithFrameError(ctx, ws, encoder,
controller.closeExecuteWithFrameError(ws, encoder,
fmt.Sprintf("failed while reading command stdin stream: %v", err))
return responder.Empty()
case code := <-exitCh:
exitObserved = true
case code := <-channels.exitCodeCh:
exitSeen = true
exitCode = code
case err := <-exitErrCh:
controller.closeExecuteWithFrameError(ctx, ws, encoder,
case err := <-channels.exitErrCh:
controller.closeExecuteWithFrameError(ws, encoder,
fmt.Sprintf("failed while waiting for command completion: %v", err))
return responder.Empty()
@ -346,7 +436,7 @@ func waitForSSHExecutionExit(sshSession *ssh.Session, exitCodeCh chan<- int32, e
exitCodeCh <- 0
}
func streamExecuteClientFrames(
func consumeClientInputFrames(
decoder *json.Decoder,
stdin io.WriteCloser,
errCh chan<- error,
@ -409,25 +499,33 @@ func streamExecuteClientFrames(
}
}
func streamExecuteOutput(
func forwardSSHOutputFrames(
ctx context.Context,
reader io.Reader,
frameType execstream.FrameType,
outputCh chan<- execstream.Frame,
doneCh chan<- struct{},
errCh chan<- error,
outputFrameCh chan<- execstream.Frame,
outputDoneCh chan<- struct{},
outputErrCh chan<- error,
) {
defer func() {
doneCh <- struct{}{}
outputDoneCh <- struct{}{}
}()
buffer := make([]byte, 4096)
for {
buffer := make([]byte, 4096)
n, err := reader.Read(buffer)
if n > 0 {
outputCh <- execstream.Frame{
frame := execstream.Frame{
Type: frameType,
Data: append([]byte(nil), buffer[:n]...),
}
select {
case outputFrameCh <- frame:
case <-ctx.Done():
return
}
}
if errors.Is(err, io.EOF) {
@ -436,7 +534,7 @@ func streamExecuteOutput(
if err != nil {
select {
case errCh <- err:
case outputErrCh <- err:
default:
}
@ -464,7 +562,6 @@ func shellQuoteArg(arg string) string {
}
func (controller *Controller) closeExecuteWithFrameError(
ctx context.Context,
wsConn *websocket.Conn,
encoder *json.Encoder,
message string,

View File

@ -3,8 +3,11 @@ package controller
import (
"bytes"
"context"
"errors"
"io"
"testing"
"time"
"github.com/cirruslabs/orchard/internal/execstream"
"github.com/stretchr/testify/require"
@ -21,7 +24,7 @@ func (writer *recordingWriteCloser) Close() error {
return nil
}
func TestStreamExecuteClientFramesWritesInputAndClosesOnEOFFrame(t *testing.T) {
func TestConsumeClientInputFramesWritesInputAndClosesOnEOFFrame(t *testing.T) {
var input bytes.Buffer
encoder := execstream.NewEncoder(&input)
@ -41,14 +44,14 @@ func TestStreamExecuteClientFramesWritesInputAndClosesOnEOFFrame(t *testing.T) {
stdin := &recordingWriteCloser{}
errCh := make(chan error, 1)
streamExecuteClientFrames(decoder, stdin, errCh)
consumeClientInputFrames(decoder, stdin, errCh)
require.NoError(t, <-errCh)
require.True(t, stdin.closed)
require.Equal(t, "hello", stdin.String())
}
func TestStreamExecuteClientFramesUnsupportedType(t *testing.T) {
func TestConsumeClientInputFramesUnsupportedType(t *testing.T) {
var input bytes.Buffer
encoder := execstream.NewEncoder(&input)
@ -61,33 +64,33 @@ func TestStreamExecuteClientFramesUnsupportedType(t *testing.T) {
stdin := &recordingWriteCloser{}
errCh := make(chan error, 1)
streamExecuteClientFrames(decoder, stdin, errCh)
consumeClientInputFrames(decoder, stdin, errCh)
require.EqualError(t, <-errCh, "unsupported frame type \"stdout\" received from client")
require.False(t, stdin.closed)
}
func TestStreamExecuteClientFramesClosesStdinOnDecodeError(t *testing.T) {
func TestConsumeClientInputFramesClosesStdinOnDecodeError(t *testing.T) {
decoder := execstream.NewDecoder(bytes.NewBuffer(nil))
stdin := &recordingWriteCloser{}
errCh := make(chan error, 1)
streamExecuteClientFrames(decoder, stdin, errCh)
consumeClientInputFrames(decoder, stdin, errCh)
require.ErrorIs(t, <-errCh, io.EOF)
require.True(t, stdin.closed)
}
func TestStreamExecuteOutputEmitsFrameAndSignalsDone(t *testing.T) {
outputCh := make(chan execstream.Frame, 1)
doneCh := make(chan struct{}, 1)
errCh := make(chan error, 1)
func TestForwardSSHOutputFramesEmitsFrameAndSignalsDone(t *testing.T) {
outputFrameCh := make(chan execstream.Frame, 1)
outputDoneCh := make(chan struct{}, 1)
outputErrCh := make(chan error, 1)
streamExecuteOutput(bytes.NewBufferString("payload"),
execstream.FrameTypeStderr, outputCh, doneCh, errCh)
forwardSSHOutputFrames(context.Background(), bytes.NewBufferString("payload"),
execstream.FrameTypeStderr, outputFrameCh, outputDoneCh, outputErrCh)
select {
case frame := <-outputCh:
case frame := <-outputFrameCh:
require.Equal(t, execstream.FrameTypeStderr, frame.Type)
require.Equal(t, []byte("payload"), frame.Data)
default:
@ -95,18 +98,68 @@ func TestStreamExecuteOutputEmitsFrameAndSignalsDone(t *testing.T) {
}
select {
case <-doneCh:
case <-outputDoneCh:
default:
t.Fatal("expected done signal")
}
select {
case err := <-errCh:
case err := <-outputErrCh:
t.Fatalf("unexpected error: %v", err)
default:
}
}
func TestForwardSSHOutputFramesStopsWhenContextCancelledWhileOutputChannelIsBlocked(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
outputFrameCh := make(chan execstream.Frame, 1)
outputFrameCh <- execstream.Frame{
Type: execstream.FrameTypeStdout,
Data: []byte("occupied"),
}
outputDoneCh := make(chan struct{}, 1)
outputErrCh := make(chan error, 1)
finished := make(chan struct{})
go func() {
forwardSSHOutputFrames(ctx, bytes.NewBufferString("payload"),
execstream.FrameTypeStdout, outputFrameCh, outputDoneCh, outputErrCh)
close(finished)
}()
select {
case <-finished:
t.Fatal("forwardSSHOutputFrames unexpectedly returned before context cancellation")
case <-time.After(50 * time.Millisecond):
}
cancel()
select {
case <-finished:
case <-time.After(time.Second):
t.Fatal("forwardSSHOutputFrames did not return after context cancellation")
}
select {
case <-outputDoneCh:
default:
t.Fatal("expected done signal")
}
}
func TestFirstExecuteOutputErrorReturnsFirstNonNilError(t *testing.T) {
outputErrCh := make(chan error, 3)
outputErrCh <- nil
outputErrCh <- errors.New("first error")
outputErrCh <- errors.New("second error")
require.EqualError(t, firstExecuteOutputError(outputErrCh), "first error")
}
func TestBuildSSHCommandQuotesArguments(t *testing.T) {
result := buildSSHCommand("echo", []string{"hello world", "a'b", ""})