From 8da34cbc06424cbdb57afe8bc4fc68034889be9b Mon Sep 17 00:00:00 2001 From: Nikolay Edigaryev Date: Wed, 11 Feb 2026 00:03:23 +0100 Subject: [PATCH] Split SSH connection and execution to avoid standard input handoff --- internal/controller/api_vms_exec.go | 186 ++---------------- internal/controller/api_vms_portforward.go | 8 +- internal/controller/sshexec/sshexec.go | 198 ++++++++++++++++++++ internal/controller/sshexec/sshexec_test.go | 26 +++ 4 files changed, 246 insertions(+), 172 deletions(-) create mode 100644 internal/controller/sshexec/sshexec.go create mode 100644 internal/controller/sshexec/sshexec_test.go diff --git a/internal/controller/api_vms_exec.go b/internal/controller/api_vms_exec.go index 78ba8e6..9dd107b 100644 --- a/internal/controller/api_vms_exec.go +++ b/internal/controller/api_vms_exec.go @@ -6,19 +6,16 @@ import ( "errors" "fmt" "io" - "net" "net/http" - "slices" "strconv" "time" + "github.com/cirruslabs/orchard/internal/controller/sshexec" "github.com/cirruslabs/orchard/internal/execstream" "github.com/cirruslabs/orchard/internal/responder" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/coder/websocket" "github.com/gin-gonic/gin" - "golang.org/x/crypto/ssh" - "golang.org/x/sync/errgroup" ) func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { @@ -61,6 +58,13 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { } defer portForwardCancel() + // Establish an SSH connection to a VM + exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin) + if err != nil { + return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err)) + } + defer exec.Close() + // Upgrade HTTP request to a WebSocket connection wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, @@ -76,26 +80,21 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { _ = wsConn.CloseNow() }() - // Start a goroutine that establishes an SSH connection to a VM and runs a command - sshErrCh := make(chan error, 1) - stdinHandleCh := make(chan io.WriteCloser, 1) - outgoingFrames := make(chan *execstream.Frame) + // Read WebSocket frames + readFramesErrCh := make(chan error, 1) go func() { - sshErrCh <- controller.execSSH(ctx, portForwardConn, vm, stdin, stdinHandleCh, command, outgoingFrames) + readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin()) }() - var readFramesErrCh chan error + // Run the command + sshErrCh := make(chan error, 1) + outgoingFrames := make(chan *execstream.Frame) + go func() { + sshErrCh <- exec.Run(ctx, command, outgoingFrames) + }() for { select { - case stdinHandle := <-stdinHandleCh: - // SSH session is almost up, we have the standard input handle, - // so we can start a goroutine that reads WebSocket frames - readFramesErrCh = make(chan error, 1) - - go func() { - readFramesErrCh <- controller.readFrames(ctx, wsConn, stdinHandle) - }() case readFramesErr := <-readFramesErrCh: controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr) @@ -154,7 +153,7 @@ func (controller *Controller) readFrames( messageType, payloadBytes, err := wsConn.Read(ctx) if err != nil { var closeErr websocket.CloseError - if errors.As(err, &closeErr) { + if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure { return nil } @@ -196,152 +195,3 @@ func (controller *Controller) readFrames( } } } - -func (controller *Controller) execSSH( - ctx context.Context, - portForwardConn net.Conn, - vm *v1.VM, - stdin bool, - stdinHandleCh chan<- io.WriteCloser, - command string, - outgoingFrames chan<- *execstream.Frame, -) error { - // Establish an SSH connection - sshConn, sshChans, sshReqs, err := ssh.NewClientConn(portForwardConn, "", &ssh.ClientConfig{ - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - return nil - }, - User: vm.SSHUsername(), - Auth: []ssh.AuthMethod{ - ssh.Password(vm.SSHPassword()), - }, - }) - if err != nil { - return fmt.Errorf("failed to a new SSH connection: %w", err) - } - - sshClient := ssh.NewClient(sshConn, sshChans, sshReqs) - defer sshClient.Close() - - // Create a new SSH session - sshSession, err := sshClient.NewSession() - if err != nil { - return fmt.Errorf("failed to create a new SSH session: %w", err) - } - defer sshSession.Close() - - if stdin { - stdin, err := sshSession.StdinPipe() - if err != nil { - return fmt.Errorf("failed to create standard input pipe: %w", err) - } - - stdinHandleCh <- stdin - } else { - stdinHandleCh <- nil - } - - stdout, err := sshSession.StdoutPipe() - if err != nil { - return fmt.Errorf("failed to create standard output pipe: %w", err) - } - - stderr, err := sshSession.StderrPipe() - if err != nil { - return fmt.Errorf("failed to create standard error pipe: %w", err) - } - - if err := sshSession.Start(command); err != nil { - return fmt.Errorf("failed to start command %q: %w", command, err) - } - - // Read bytes from standard output and standard error and stream them as frames - ioGroup, ioGroupCtx := errgroup.WithContext(ctx) - - ioGroup.Go(func() error { - return ioStreamReader(ioGroupCtx, stdout, execstream.FrameTypeStdout, outgoingFrames) - }) - ioGroup.Go(func() error { - return ioStreamReader(ioGroupCtx, stderr, execstream.FrameTypeStderr, outgoingFrames) - }) - - sshWaitErrCh := make(chan error, 1) - go func() { - sshWaitErrCh <- sshSession.Wait() - }() - - // Wait for SSH command terminate while respecting context - var sshWaitErr error - - select { - case sshWaitErr = <-sshWaitErrCh: - // Proceed - case <-ctx.Done(): - return ctx.Err() - } - - // Wait for the I/O to complete, otherwise we may - // miss some bits of the command's output/error - if err := ioGroup.Wait(); err != nil { - return err - } - - // Post an exit event - exitFrame := &execstream.Frame{ - Type: execstream.FrameTypeExit, - Exit: execstream.Exit{ - Code: 0, - }, - } - - if sshWaitErr != nil { - var sshExitError *ssh.ExitError - if errors.As(sshWaitErr, &sshExitError) { - exitFrame.Exit.Code = int32(sshExitError.ExitStatus()) - } else { - return fmt.Errorf("failed to execute command %q: %w", command, sshWaitErr) - } - } - - select { - case outgoingFrames <- exitFrame: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -func ioStreamReader( - ctx context.Context, - r io.Reader, - frameType execstream.FrameType, - ch chan<- *execstream.Frame, -) error { - buf := make([]byte, 4096) - - for { - n, err := r.Read(buf) - - if n > 0 { - frame := &execstream.Frame{ - Type: frameType, - Data: slices.Clone(buf[:n]), - } - - select { - case <-ctx.Done(): - return ctx.Err() - case ch <- frame: - // Proceed - } - } - - if err != nil { - if errors.Is(err, io.EOF) { - return nil - } - - return err - } - } -} diff --git a/internal/controller/api_vms_portforward.go b/internal/controller/api_vms_portforward.go index dcc807c..20d085a 100644 --- a/internal/controller/api_vms_portforward.go +++ b/internal/controller/api_vms_portforward.go @@ -206,7 +206,7 @@ func (controller *Controller) waitForVM(ctx context.Context, name string) (*v1.V func (controller *Controller) portForwardConnection( ctx context.Context, - notifyContext context.Context, + waitContext context.Context, workerName string, vmUID string, port uint32, @@ -223,7 +223,7 @@ func (controller *Controller) portForwardConnection( } // Send request to a worker to initiate a port forwarding connection back to us - err := controller.workerNotifier.Notify(notifyContext, workerName, &rpc.WatchInstruction{ + err := controller.workerNotifier.Notify(waitContext, workerName, &rpc.WatchInstruction{ Action: &rpc.WatchInstruction_PortForwardAction{ PortForwardAction: &rpc.WatchInstruction_PortForward{ Session: session, @@ -250,9 +250,9 @@ func (controller *Controller) portForwardConnection( } return rendezvousResponse.Result, cancel, nil - case <-ctx.Done(): + case <-waitContext.Done(): cancel() - return nil, nil, ctx.Err() + return nil, nil, waitContext.Err() } } diff --git a/internal/controller/sshexec/sshexec.go b/internal/controller/sshexec/sshexec.go new file mode 100644 index 0000000..b5bf1f4 --- /dev/null +++ b/internal/controller/sshexec/sshexec.go @@ -0,0 +1,198 @@ +package sshexec + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "slices" + + "github.com/cirruslabs/orchard/internal/execstream" + "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" +) + +type Exec struct { + sshClient *ssh.Client + sshSession *ssh.Session + stdout io.Reader + stderr io.Reader + stdin io.WriteCloser +} + +func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, error) { + // Establish an SSH connection + sshConn, sshChans, sshReqs, err := ssh.NewClientConn(netConn, "", &ssh.ClientConfig{ + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + User: user, + Auth: []ssh.AuthMethod{ + ssh.Password(password), + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to create an SSH connection: %w", err) + } + + sshClient := ssh.NewClient(sshConn, sshChans, sshReqs) + + // Create a new SSH session + sshSession, err := sshClient.NewSession() + if err != nil { + _ = sshClient.Close() + + return nil, fmt.Errorf("failed to create an SSH session: %w", err) + } + + exec := &Exec{ + sshClient: sshClient, + sshSession: sshSession, + } + + if stdin { + exec.stdin, err = sshSession.StdinPipe() + if err != nil { + _ = sshSession.Close() + _ = sshClient.Close() + + return nil, fmt.Errorf("failed to create standard input pipe "+ + "for the SSH session: %w", err) + } + } + + exec.stdout, err = sshSession.StdoutPipe() + if err != nil { + _ = sshSession.Close() + _ = sshClient.Close() + + return nil, fmt.Errorf("failed to create standard output pipe "+ + "for the SSH session: %w", err) + } + + exec.stderr, err = sshSession.StderrPipe() + if err != nil { + _ = sshSession.Close() + _ = sshClient.Close() + + return nil, fmt.Errorf("failed to create standard error pipe "+ + "for the SSH session: %w", err) + } + + return exec, nil +} + +func (exec *Exec) Stdin() io.WriteCloser { + return exec.stdin +} + +func (exec *Exec) Run( + ctx context.Context, + command string, + outgoingFrames chan<- *execstream.Frame, +) error { + if err := exec.sshSession.Start(command); err != nil { + return fmt.Errorf("failed to start command %q: %w", command, err) + } + + // Read bytes from standard output and standard error and stream them as frames + ioGroup, ioGroupCtx := errgroup.WithContext(ctx) + + ioGroup.Go(func() error { + return ioStreamReader(ioGroupCtx, exec.stdout, execstream.FrameTypeStdout, outgoingFrames) + }) + ioGroup.Go(func() error { + return ioStreamReader(ioGroupCtx, exec.stderr, execstream.FrameTypeStderr, outgoingFrames) + }) + + sshWaitErrCh := make(chan error, 1) + go func() { + sshWaitErrCh <- exec.sshSession.Wait() + }() + + // Wait for SSH command terminate while respecting context + var sshWaitErr error + + select { + case sshWaitErr = <-sshWaitErrCh: + // Proceed + case <-ctx.Done(): + return ctx.Err() + } + + // Wait for the I/O to complete, otherwise we may + // miss some bits of the command's output/error + if err := ioGroup.Wait(); err != nil { + return err + } + + // Post an exit event + exitFrame := &execstream.Frame{ + Type: execstream.FrameTypeExit, + Exit: execstream.Exit{ + Code: 0, + }, + } + + if sshWaitErr != nil { + var sshExitError *ssh.ExitError + if errors.As(sshWaitErr, &sshExitError) { + exitFrame.Exit.Code = int32(sshExitError.ExitStatus()) + } else { + return fmt.Errorf("failed to execute command %q: %w", command, sshWaitErr) + } + } + + select { + case outgoingFrames <- exitFrame: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func ioStreamReader( + ctx context.Context, + r io.Reader, + frameType execstream.FrameType, + ch chan<- *execstream.Frame, +) error { + buf := make([]byte, 4096) + + for { + n, err := r.Read(buf) + + if n > 0 { + frame := &execstream.Frame{ + Type: frameType, + Data: slices.Clone(buf[:n]), + } + + select { + case <-ctx.Done(): + return ctx.Err() + case ch <- frame: + // Proceed + } + } + + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + + return err + } + } +} + +func (exec *Exec) Close() error { + if err := exec.sshSession.Close(); err != nil { + _ = exec.sshClient.Close() + + return err + } + + return exec.sshClient.Close() +} diff --git a/internal/controller/sshexec/sshexec_test.go b/internal/controller/sshexec/sshexec_test.go new file mode 100644 index 0000000..bd360ea --- /dev/null +++ b/internal/controller/sshexec/sshexec_test.go @@ -0,0 +1,26 @@ +package sshexec_test + +import ( + "net" + "testing" + "time" + + "github.com/cirruslabs/orchard/internal/controller/sshexec" + "github.com/stretchr/testify/require" +) + +func TestContextCancellationViaNetConnClose(t *testing.T) { + clientConn, serverConn := net.Pipe() + + go func() { + select { + case <-t.Context().Done(): + return + case <-time.After(5 * time.Second): + require.NoError(t, serverConn.Close()) + } + }() + + _, err := sshexec.New(clientConn, "doesn't", "matter", false) + require.Error(t, err) +}