Split SSH connection and execution to avoid standard input handoff
This commit is contained in:
parent
5111950f11
commit
8da34cbc06
|
|
@ -6,19 +6,16 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/controller/sshexec"
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
"github.com/cirruslabs/orchard/internal/responder"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
||||
|
|
@ -61,6 +58,13 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
}
|
||||
defer portForwardCancel()
|
||||
|
||||
// Establish an SSH connection to a VM
|
||||
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err))
|
||||
}
|
||||
defer exec.Close()
|
||||
|
||||
// Upgrade HTTP request to a WebSocket connection
|
||||
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
|
|
@ -76,26 +80,21 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
|
||||
// Start a goroutine that establishes an SSH connection to a VM and runs a command
|
||||
sshErrCh := make(chan error, 1)
|
||||
stdinHandleCh := make(chan io.WriteCloser, 1)
|
||||
outgoingFrames := make(chan *execstream.Frame)
|
||||
// Read WebSocket frames
|
||||
readFramesErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
sshErrCh <- controller.execSSH(ctx, portForwardConn, vm, stdin, stdinHandleCh, command, outgoingFrames)
|
||||
readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin())
|
||||
}()
|
||||
|
||||
var readFramesErrCh chan error
|
||||
// Run the command
|
||||
sshErrCh := make(chan error, 1)
|
||||
outgoingFrames := make(chan *execstream.Frame)
|
||||
go func() {
|
||||
sshErrCh <- exec.Run(ctx, command, outgoingFrames)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case stdinHandle := <-stdinHandleCh:
|
||||
// SSH session is almost up, we have the standard input handle,
|
||||
// so we can start a goroutine that reads WebSocket frames
|
||||
readFramesErrCh = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
readFramesErrCh <- controller.readFrames(ctx, wsConn, stdinHandle)
|
||||
}()
|
||||
case readFramesErr := <-readFramesErrCh:
|
||||
controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr)
|
||||
|
||||
|
|
@ -154,7 +153,7 @@ func (controller *Controller) readFrames(
|
|||
messageType, payloadBytes, err := wsConn.Read(ctx)
|
||||
if err != nil {
|
||||
var closeErr websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -196,152 +195,3 @@ func (controller *Controller) readFrames(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (controller *Controller) execSSH(
|
||||
ctx context.Context,
|
||||
portForwardConn net.Conn,
|
||||
vm *v1.VM,
|
||||
stdin bool,
|
||||
stdinHandleCh chan<- io.WriteCloser,
|
||||
command string,
|
||||
outgoingFrames chan<- *execstream.Frame,
|
||||
) error {
|
||||
// Establish an SSH connection
|
||||
sshConn, sshChans, sshReqs, err := ssh.NewClientConn(portForwardConn, "", &ssh.ClientConfig{
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
User: vm.SSHUsername(),
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(vm.SSHPassword()),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to a new SSH connection: %w", err)
|
||||
}
|
||||
|
||||
sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Create a new SSH session
|
||||
sshSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create a new SSH session: %w", err)
|
||||
}
|
||||
defer sshSession.Close()
|
||||
|
||||
if stdin {
|
||||
stdin, err := sshSession.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create standard input pipe: %w", err)
|
||||
}
|
||||
|
||||
stdinHandleCh <- stdin
|
||||
} else {
|
||||
stdinHandleCh <- nil
|
||||
}
|
||||
|
||||
stdout, err := sshSession.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create standard output pipe: %w", err)
|
||||
}
|
||||
|
||||
stderr, err := sshSession.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create standard error pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := sshSession.Start(command); err != nil {
|
||||
return fmt.Errorf("failed to start command %q: %w", command, err)
|
||||
}
|
||||
|
||||
// Read bytes from standard output and standard error and stream them as frames
|
||||
ioGroup, ioGroupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
ioGroup.Go(func() error {
|
||||
return ioStreamReader(ioGroupCtx, stdout, execstream.FrameTypeStdout, outgoingFrames)
|
||||
})
|
||||
ioGroup.Go(func() error {
|
||||
return ioStreamReader(ioGroupCtx, stderr, execstream.FrameTypeStderr, outgoingFrames)
|
||||
})
|
||||
|
||||
sshWaitErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
sshWaitErrCh <- sshSession.Wait()
|
||||
}()
|
||||
|
||||
// Wait for SSH command terminate while respecting context
|
||||
var sshWaitErr error
|
||||
|
||||
select {
|
||||
case sshWaitErr = <-sshWaitErrCh:
|
||||
// Proceed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Wait for the I/O to complete, otherwise we may
|
||||
// miss some bits of the command's output/error
|
||||
if err := ioGroup.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Post an exit event
|
||||
exitFrame := &execstream.Frame{
|
||||
Type: execstream.FrameTypeExit,
|
||||
Exit: execstream.Exit{
|
||||
Code: 0,
|
||||
},
|
||||
}
|
||||
|
||||
if sshWaitErr != nil {
|
||||
var sshExitError *ssh.ExitError
|
||||
if errors.As(sshWaitErr, &sshExitError) {
|
||||
exitFrame.Exit.Code = int32(sshExitError.ExitStatus())
|
||||
} else {
|
||||
return fmt.Errorf("failed to execute command %q: %w", command, sshWaitErr)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case outgoingFrames <- exitFrame:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func ioStreamReader(
|
||||
ctx context.Context,
|
||||
r io.Reader,
|
||||
frameType execstream.FrameType,
|
||||
ch chan<- *execstream.Frame,
|
||||
) error {
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
|
||||
if n > 0 {
|
||||
frame := &execstream.Frame{
|
||||
Type: frameType,
|
||||
Data: slices.Clone(buf[:n]),
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- frame:
|
||||
// Proceed
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ func (controller *Controller) waitForVM(ctx context.Context, name string) (*v1.V
|
|||
|
||||
func (controller *Controller) portForwardConnection(
|
||||
ctx context.Context,
|
||||
notifyContext context.Context,
|
||||
waitContext context.Context,
|
||||
workerName string,
|
||||
vmUID string,
|
||||
port uint32,
|
||||
|
|
@ -223,7 +223,7 @@ func (controller *Controller) portForwardConnection(
|
|||
}
|
||||
|
||||
// Send request to a worker to initiate a port forwarding connection back to us
|
||||
err := controller.workerNotifier.Notify(notifyContext, workerName, &rpc.WatchInstruction{
|
||||
err := controller.workerNotifier.Notify(waitContext, workerName, &rpc.WatchInstruction{
|
||||
Action: &rpc.WatchInstruction_PortForwardAction{
|
||||
PortForwardAction: &rpc.WatchInstruction_PortForward{
|
||||
Session: session,
|
||||
|
|
@ -250,9 +250,9 @@ func (controller *Controller) portForwardConnection(
|
|||
}
|
||||
|
||||
return rendezvousResponse.Result, cancel, nil
|
||||
case <-ctx.Done():
|
||||
case <-waitContext.Done():
|
||||
cancel()
|
||||
|
||||
return nil, nil, ctx.Err()
|
||||
return nil, nil, waitContext.Err()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
package sshexec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type Exec struct {
|
||||
sshClient *ssh.Client
|
||||
sshSession *ssh.Session
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
stdin io.WriteCloser
|
||||
}
|
||||
|
||||
func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, error) {
|
||||
// Establish an SSH connection
|
||||
sshConn, sshChans, sshReqs, err := ssh.NewClientConn(netConn, "", &ssh.ClientConfig{
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(password),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create an SSH connection: %w", err)
|
||||
}
|
||||
|
||||
sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
|
||||
|
||||
// Create a new SSH session
|
||||
sshSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_ = sshClient.Close()
|
||||
|
||||
return nil, fmt.Errorf("failed to create an SSH session: %w", err)
|
||||
}
|
||||
|
||||
exec := &Exec{
|
||||
sshClient: sshClient,
|
||||
sshSession: sshSession,
|
||||
}
|
||||
|
||||
if stdin {
|
||||
exec.stdin, err = sshSession.StdinPipe()
|
||||
if err != nil {
|
||||
_ = sshSession.Close()
|
||||
_ = sshClient.Close()
|
||||
|
||||
return nil, fmt.Errorf("failed to create standard input pipe "+
|
||||
"for the SSH session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
exec.stdout, err = sshSession.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = sshSession.Close()
|
||||
_ = sshClient.Close()
|
||||
|
||||
return nil, fmt.Errorf("failed to create standard output pipe "+
|
||||
"for the SSH session: %w", err)
|
||||
}
|
||||
|
||||
exec.stderr, err = sshSession.StderrPipe()
|
||||
if err != nil {
|
||||
_ = sshSession.Close()
|
||||
_ = sshClient.Close()
|
||||
|
||||
return nil, fmt.Errorf("failed to create standard error pipe "+
|
||||
"for the SSH session: %w", err)
|
||||
}
|
||||
|
||||
return exec, nil
|
||||
}
|
||||
|
||||
func (exec *Exec) Stdin() io.WriteCloser {
|
||||
return exec.stdin
|
||||
}
|
||||
|
||||
func (exec *Exec) Run(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
outgoingFrames chan<- *execstream.Frame,
|
||||
) error {
|
||||
if err := exec.sshSession.Start(command); err != nil {
|
||||
return fmt.Errorf("failed to start command %q: %w", command, err)
|
||||
}
|
||||
|
||||
// Read bytes from standard output and standard error and stream them as frames
|
||||
ioGroup, ioGroupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
ioGroup.Go(func() error {
|
||||
return ioStreamReader(ioGroupCtx, exec.stdout, execstream.FrameTypeStdout, outgoingFrames)
|
||||
})
|
||||
ioGroup.Go(func() error {
|
||||
return ioStreamReader(ioGroupCtx, exec.stderr, execstream.FrameTypeStderr, outgoingFrames)
|
||||
})
|
||||
|
||||
sshWaitErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
sshWaitErrCh <- exec.sshSession.Wait()
|
||||
}()
|
||||
|
||||
// Wait for SSH command terminate while respecting context
|
||||
var sshWaitErr error
|
||||
|
||||
select {
|
||||
case sshWaitErr = <-sshWaitErrCh:
|
||||
// Proceed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Wait for the I/O to complete, otherwise we may
|
||||
// miss some bits of the command's output/error
|
||||
if err := ioGroup.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Post an exit event
|
||||
exitFrame := &execstream.Frame{
|
||||
Type: execstream.FrameTypeExit,
|
||||
Exit: execstream.Exit{
|
||||
Code: 0,
|
||||
},
|
||||
}
|
||||
|
||||
if sshWaitErr != nil {
|
||||
var sshExitError *ssh.ExitError
|
||||
if errors.As(sshWaitErr, &sshExitError) {
|
||||
exitFrame.Exit.Code = int32(sshExitError.ExitStatus())
|
||||
} else {
|
||||
return fmt.Errorf("failed to execute command %q: %w", command, sshWaitErr)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case outgoingFrames <- exitFrame:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func ioStreamReader(
|
||||
ctx context.Context,
|
||||
r io.Reader,
|
||||
frameType execstream.FrameType,
|
||||
ch chan<- *execstream.Frame,
|
||||
) error {
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
|
||||
if n > 0 {
|
||||
frame := &execstream.Frame{
|
||||
Type: frameType,
|
||||
Data: slices.Clone(buf[:n]),
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- frame:
|
||||
// Proceed
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (exec *Exec) Close() error {
|
||||
if err := exec.sshSession.Close(); err != nil {
|
||||
_ = exec.sshClient.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return exec.sshClient.Close()
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
package sshexec_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/controller/sshexec"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestContextCancellationViaNetConnClose(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-t.Context().Done():
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
require.NoError(t, serverConn.Close())
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := sshexec.New(clientConn, "doesn't", "matter", false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
Loading…
Reference in New Issue