Split SSH connection and execution to avoid standard input handoff

This commit is contained in:
Nikolay Edigaryev 2026-02-11 00:03:23 +01:00
parent 5111950f11
commit 8da34cbc06
4 changed files with 246 additions and 172 deletions

View File

@ -6,19 +6,16 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"slices"
"strconv"
"time"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
"github.com/cirruslabs/orchard/internal/execstream"
"github.com/cirruslabs/orchard/internal/responder"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
@ -61,6 +58,13 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
}
defer portForwardCancel()
// Establish an SSH connection to a VM
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
if err != nil {
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err))
}
defer exec.Close()
// Upgrade HTTP request to a WebSocket connection
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
@ -76,26 +80,21 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
_ = wsConn.CloseNow()
}()
// Start a goroutine that establishes an SSH connection to a VM and runs a command
sshErrCh := make(chan error, 1)
stdinHandleCh := make(chan io.WriteCloser, 1)
outgoingFrames := make(chan *execstream.Frame)
// Read WebSocket frames
readFramesErrCh := make(chan error, 1)
go func() {
sshErrCh <- controller.execSSH(ctx, portForwardConn, vm, stdin, stdinHandleCh, command, outgoingFrames)
readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin())
}()
var readFramesErrCh chan error
// Run the command
sshErrCh := make(chan error, 1)
outgoingFrames := make(chan *execstream.Frame)
go func() {
sshErrCh <- exec.Run(ctx, command, outgoingFrames)
}()
for {
select {
case stdinHandle := <-stdinHandleCh:
// SSH session is almost up, we have the standard input handle,
// so we can start a goroutine that reads WebSocket frames
readFramesErrCh = make(chan error, 1)
go func() {
readFramesErrCh <- controller.readFrames(ctx, wsConn, stdinHandle)
}()
case readFramesErr := <-readFramesErrCh:
controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr)
@ -154,7 +153,7 @@ func (controller *Controller) readFrames(
messageType, payloadBytes, err := wsConn.Read(ctx)
if err != nil {
var closeErr websocket.CloseError
if errors.As(err, &closeErr) {
if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
return nil
}
@ -196,152 +195,3 @@ func (controller *Controller) readFrames(
}
}
}
func (controller *Controller) execSSH(
ctx context.Context,
portForwardConn net.Conn,
vm *v1.VM,
stdin bool,
stdinHandleCh chan<- io.WriteCloser,
command string,
outgoingFrames chan<- *execstream.Frame,
) error {
// Establish an SSH connection
sshConn, sshChans, sshReqs, err := ssh.NewClientConn(portForwardConn, "", &ssh.ClientConfig{
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
User: vm.SSHUsername(),
Auth: []ssh.AuthMethod{
ssh.Password(vm.SSHPassword()),
},
})
if err != nil {
return fmt.Errorf("failed to a new SSH connection: %w", err)
}
sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
defer sshClient.Close()
// Create a new SSH session
sshSession, err := sshClient.NewSession()
if err != nil {
return fmt.Errorf("failed to create a new SSH session: %w", err)
}
defer sshSession.Close()
if stdin {
stdin, err := sshSession.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create standard input pipe: %w", err)
}
stdinHandleCh <- stdin
} else {
stdinHandleCh <- nil
}
stdout, err := sshSession.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create standard output pipe: %w", err)
}
stderr, err := sshSession.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create standard error pipe: %w", err)
}
if err := sshSession.Start(command); err != nil {
return fmt.Errorf("failed to start command %q: %w", command, err)
}
// Read bytes from standard output and standard error and stream them as frames
ioGroup, ioGroupCtx := errgroup.WithContext(ctx)
ioGroup.Go(func() error {
return ioStreamReader(ioGroupCtx, stdout, execstream.FrameTypeStdout, outgoingFrames)
})
ioGroup.Go(func() error {
return ioStreamReader(ioGroupCtx, stderr, execstream.FrameTypeStderr, outgoingFrames)
})
sshWaitErrCh := make(chan error, 1)
go func() {
sshWaitErrCh <- sshSession.Wait()
}()
// Wait for SSH command terminate while respecting context
var sshWaitErr error
select {
case sshWaitErr = <-sshWaitErrCh:
// Proceed
case <-ctx.Done():
return ctx.Err()
}
// Wait for the I/O to complete, otherwise we may
// miss some bits of the command's output/error
if err := ioGroup.Wait(); err != nil {
return err
}
// Post an exit event
exitFrame := &execstream.Frame{
Type: execstream.FrameTypeExit,
Exit: execstream.Exit{
Code: 0,
},
}
if sshWaitErr != nil {
var sshExitError *ssh.ExitError
if errors.As(sshWaitErr, &sshExitError) {
exitFrame.Exit.Code = int32(sshExitError.ExitStatus())
} else {
return fmt.Errorf("failed to execute command %q: %w", command, sshWaitErr)
}
}
select {
case outgoingFrames <- exitFrame:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func ioStreamReader(
ctx context.Context,
r io.Reader,
frameType execstream.FrameType,
ch chan<- *execstream.Frame,
) error {
buf := make([]byte, 4096)
for {
n, err := r.Read(buf)
if n > 0 {
frame := &execstream.Frame{
Type: frameType,
Data: slices.Clone(buf[:n]),
}
select {
case <-ctx.Done():
return ctx.Err()
case ch <- frame:
// Proceed
}
}
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
}
}

View File

@ -206,7 +206,7 @@ func (controller *Controller) waitForVM(ctx context.Context, name string) (*v1.V
func (controller *Controller) portForwardConnection(
ctx context.Context,
notifyContext context.Context,
waitContext context.Context,
workerName string,
vmUID string,
port uint32,
@ -223,7 +223,7 @@ func (controller *Controller) portForwardConnection(
}
// Send request to a worker to initiate a port forwarding connection back to us
err := controller.workerNotifier.Notify(notifyContext, workerName, &rpc.WatchInstruction{
err := controller.workerNotifier.Notify(waitContext, workerName, &rpc.WatchInstruction{
Action: &rpc.WatchInstruction_PortForwardAction{
PortForwardAction: &rpc.WatchInstruction_PortForward{
Session: session,
@ -250,9 +250,9 @@ func (controller *Controller) portForwardConnection(
}
return rendezvousResponse.Result, cancel, nil
case <-ctx.Done():
case <-waitContext.Done():
cancel()
return nil, nil, ctx.Err()
return nil, nil, waitContext.Err()
}
}

View File

@ -0,0 +1,198 @@
package sshexec
import (
"context"
"errors"
"fmt"
"io"
"net"
"slices"
"github.com/cirruslabs/orchard/internal/execstream"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
type Exec struct {
sshClient *ssh.Client
sshSession *ssh.Session
stdout io.Reader
stderr io.Reader
stdin io.WriteCloser
}
func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, error) {
// Establish an SSH connection
sshConn, sshChans, sshReqs, err := ssh.NewClientConn(netConn, "", &ssh.ClientConfig{
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
User: user,
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
})
if err != nil {
return nil, fmt.Errorf("failed to create an SSH connection: %w", err)
}
sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
// Create a new SSH session
sshSession, err := sshClient.NewSession()
if err != nil {
_ = sshClient.Close()
return nil, fmt.Errorf("failed to create an SSH session: %w", err)
}
exec := &Exec{
sshClient: sshClient,
sshSession: sshSession,
}
if stdin {
exec.stdin, err = sshSession.StdinPipe()
if err != nil {
_ = sshSession.Close()
_ = sshClient.Close()
return nil, fmt.Errorf("failed to create standard input pipe "+
"for the SSH session: %w", err)
}
}
exec.stdout, err = sshSession.StdoutPipe()
if err != nil {
_ = sshSession.Close()
_ = sshClient.Close()
return nil, fmt.Errorf("failed to create standard output pipe "+
"for the SSH session: %w", err)
}
exec.stderr, err = sshSession.StderrPipe()
if err != nil {
_ = sshSession.Close()
_ = sshClient.Close()
return nil, fmt.Errorf("failed to create standard error pipe "+
"for the SSH session: %w", err)
}
return exec, nil
}
func (exec *Exec) Stdin() io.WriteCloser {
return exec.stdin
}
func (exec *Exec) Run(
ctx context.Context,
command string,
outgoingFrames chan<- *execstream.Frame,
) error {
if err := exec.sshSession.Start(command); err != nil {
return fmt.Errorf("failed to start command %q: %w", command, err)
}
// Read bytes from standard output and standard error and stream them as frames
ioGroup, ioGroupCtx := errgroup.WithContext(ctx)
ioGroup.Go(func() error {
return ioStreamReader(ioGroupCtx, exec.stdout, execstream.FrameTypeStdout, outgoingFrames)
})
ioGroup.Go(func() error {
return ioStreamReader(ioGroupCtx, exec.stderr, execstream.FrameTypeStderr, outgoingFrames)
})
sshWaitErrCh := make(chan error, 1)
go func() {
sshWaitErrCh <- exec.sshSession.Wait()
}()
// Wait for SSH command terminate while respecting context
var sshWaitErr error
select {
case sshWaitErr = <-sshWaitErrCh:
// Proceed
case <-ctx.Done():
return ctx.Err()
}
// Wait for the I/O to complete, otherwise we may
// miss some bits of the command's output/error
if err := ioGroup.Wait(); err != nil {
return err
}
// Post an exit event
exitFrame := &execstream.Frame{
Type: execstream.FrameTypeExit,
Exit: execstream.Exit{
Code: 0,
},
}
if sshWaitErr != nil {
var sshExitError *ssh.ExitError
if errors.As(sshWaitErr, &sshExitError) {
exitFrame.Exit.Code = int32(sshExitError.ExitStatus())
} else {
return fmt.Errorf("failed to execute command %q: %w", command, sshWaitErr)
}
}
select {
case outgoingFrames <- exitFrame:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func ioStreamReader(
ctx context.Context,
r io.Reader,
frameType execstream.FrameType,
ch chan<- *execstream.Frame,
) error {
buf := make([]byte, 4096)
for {
n, err := r.Read(buf)
if n > 0 {
frame := &execstream.Frame{
Type: frameType,
Data: slices.Clone(buf[:n]),
}
select {
case <-ctx.Done():
return ctx.Err()
case ch <- frame:
// Proceed
}
}
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
}
}
func (exec *Exec) Close() error {
if err := exec.sshSession.Close(); err != nil {
_ = exec.sshClient.Close()
return err
}
return exec.sshClient.Close()
}

View File

@ -0,0 +1,26 @@
package sshexec_test
import (
"net"
"testing"
"time"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
"github.com/stretchr/testify/require"
)
func TestContextCancellationViaNetConnClose(t *testing.T) {
clientConn, serverConn := net.Pipe()
go func() {
select {
case <-t.Context().Done():
return
case <-time.After(5 * time.Second):
require.NoError(t, serverConn.Close())
}
}()
_, err := sshexec.New(clientConn, "doesn't", "matter", false)
require.Error(t, err)
}