651 lines
16 KiB
Go
651 lines
16 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/cirruslabs/orchard/internal/execstream"
|
|
"github.com/cirruslabs/orchard/internal/netconncancel"
|
|
"github.com/cirruslabs/orchard/internal/responder"
|
|
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
"github.com/cirruslabs/orchard/rpc"
|
|
"github.com/coder/websocket"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
const (
|
|
executeSessionRendezvousTimeout = 15 * time.Second
|
|
executeSessionSSHRetryDelay = 1 * time.Second
|
|
)
|
|
|
|
type sshExecution struct {
|
|
session *ssh.Session
|
|
stdout io.Reader
|
|
stderr io.Reader
|
|
stdin io.WriteCloser
|
|
}
|
|
|
|
type executeVMRequest struct {
|
|
name string
|
|
command string
|
|
args []string
|
|
wait time.Duration
|
|
}
|
|
|
|
type executeSessionChannels struct {
|
|
outputFrameCh chan execstream.Frame
|
|
outputDoneCh chan struct{}
|
|
outputErrCh chan error
|
|
stdinErrCh chan error
|
|
exitCodeCh chan int32
|
|
exitErrCh chan error
|
|
}
|
|
|
|
// executeVM handles an execute request. After authorization and parameter
// parsing, it does the following:
//   - waits (up to the requested period) for the VM;
//   - opens an SSH tunnel to it through its worker;
//   - upgrades the connection to a WebSocket;
//   - runs the command, streaming its I/O over that WebSocket.
func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder {
	// Either role is sufficient to run commands inside a VM.
	if denied := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
		v1.ServiceAccountRoleComputeConnect); denied != nil {
		return denied
	}

	req, failure := parseExecuteVMRequest(ctx)
	if failure != nil {
		return failure
	}

	// Only the VM lookup is bounded by the client-supplied wait period; the
	// tunnel and the session itself use the request context.
	waitCtx, cancelWait := context.WithTimeout(ctx, req.wait)
	defer cancelWait()

	vm, failure := controller.waitForVM(waitCtx, req.name)
	if failure != nil {
		return failure
	}

	tunnel, failure := controller.establishExecuteSSHTunnel(ctx, vm)
	if failure != nil {
		return failure
	}

	wsConn, err := acceptExecuteWebSocket(ctx)
	if err != nil {
		// The tunnel is not handed over on this path, so release it here.
		_ = tunnel.Close()

		return responder.Error(err)
	}
	defer func() {
		_ = wsConn.CloseNow()
	}()

	return controller.executeVMViaSSHTunnel(ctx, tunnel, wsConn, vm, req.command, req.args)
}
|
|
|
|
// parseExecuteVMRequest extracts the execute parameters from the request:
//   - the VM name from the route's "name" parameter;
//   - the mandatory "command" query parameter;
//   - any repeated "arg" query parameters;
//   - an optional "wait" in whole seconds (default 10, at most 65535).
//
// Missing or malformed input yields a 400 responder.
func parseExecuteVMRequest(ctx *gin.Context) (*executeVMRequest, responder.Responder) {
	command := ctx.Query("command")
	if len(command) == 0 {
		return nil, responder.Code(http.StatusBadRequest)
	}

	// bitSize 16 rejects values that do not fit in a uint16.
	waitSeconds, err := strconv.ParseUint(ctx.DefaultQuery("wait", "10"), 10, 16)
	if err != nil {
		return nil, responder.Code(http.StatusBadRequest)
	}

	req := executeVMRequest{
		name:    ctx.Param("name"),
		command: command,
		args:    ctx.QueryArray("arg"),
		wait:    time.Second * time.Duration(waitSeconds),
	}

	return &req, nil
}
|
|
|
|
func acceptExecuteWebSocket(ctx *gin.Context) (*websocket.Conn, error) {
|
|
return websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
|
OriginPatterns: []string{"*"},
|
|
})
|
|
}
|
|
|
|
// establishExecuteSSHTunnel asks the worker hosting vm to port-forward the
// VM's SSH port (22). It then waits for the resulting connection to arrive
// through the controller's connection rendezvous. Notifying the worker and
// waiting are together bounded by executeSessionRendezvousTimeout.
//
// On success the rendezvous context's cancel function is handed to
// netconncancel together with the connection (presumably invoked when the
// connection is closed). Every failure path cancels that context itself and
// returns a responder describing the failure.
func (controller *Controller) establishExecuteSSHTunnel(
	ctx context.Context,
	vm *v1.VM,
) (net.Conn, responder.Responder) {
	waitCtx, cancelWait := context.WithTimeout(ctx, executeSessionRendezvousTimeout)
	defer cancelWait()

	// Derived from ctx rather than waitCtx: on success it must outlive this
	// function, whose return cancels waitCtx.
	rendezvousCtx, cancelRendezvous := context.WithCancel(ctx)

	sessionID := uuid.New().String()
	connCh, cancelRequest := controller.connRendezvous.Request(rendezvousCtx, sessionID)
	defer cancelRequest()

	instruction := &rpc.WatchInstruction{
		Action: &rpc.WatchInstruction_PortForwardAction{
			PortForwardAction: &rpc.WatchInstruction_PortForward{
				Session: sessionID,
				VmUid:   vm.UID,
				Port:    22,
			},
		},
	}
	if err := controller.workerNotifier.Notify(waitCtx, vm.Worker, instruction); err != nil {
		cancelRendezvous()

		controller.logger.Warnf("failed to request VM SSH port-forwarding from the worker %s: %v",
			vm.Worker, err)

		return nil, responder.Code(http.StatusServiceUnavailable)
	}

	select {
	case <-waitCtx.Done():
		cancelRendezvous()

		if errors.Is(waitCtx.Err(), context.DeadlineExceeded) {
			return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
				"timed out waiting for worker %s to establish SSH tunnel", vm.Worker))
		}

		return nil, responder.Error(waitCtx.Err())
	case resp := <-connCh:
		switch {
		case resp.ErrorMessage != "":
			cancelRendezvous()

			return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
				"failed to establish SSH connection to the VM on the worker: %s", resp.ErrorMessage))
		case resp.Result == nil:
			cancelRendezvous()

			return nil, responder.Code(http.StatusServiceUnavailable)
		default:
			return netconncancel.New(resp.Result, cancelRendezvous), nil
		}
	}
}
|
|
|
|
func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel net.Conn, ws *websocket.Conn, vm *v1.VM, cmd string, args []string) responder.Responder {
|
|
defer func() {
|
|
_ = tunnel.Close()
|
|
}()
|
|
|
|
tunnel, sshClient, err := controller.establishExecuteSSHClientWithRetry(ctx, tunnel, vm)
|
|
if err != nil {
|
|
if isNormalContextCancellation(err) {
|
|
return responder.Empty()
|
|
}
|
|
|
|
controller.closeExecuteWithFrameError(ws, nil,
|
|
fmt.Sprintf("SSH handshake with the VM failed: %v", err))
|
|
|
|
return responder.Empty()
|
|
}
|
|
defer func() {
|
|
_ = sshClient.Close()
|
|
}()
|
|
|
|
execution, err := startSSHExecution(sshClient, cmd, args)
|
|
if err != nil {
|
|
controller.closeExecuteWithFrameError(ws, nil, err.Error())
|
|
|
|
return responder.Empty()
|
|
}
|
|
defer func() {
|
|
_ = execution.session.Close()
|
|
}()
|
|
|
|
return controller.runExecuteSession(ctx, ws, execution)
|
|
}
|
|
|
|
func (controller *Controller) establishExecuteSSHClientWithRetry(
|
|
ctx context.Context,
|
|
tunnel net.Conn,
|
|
vm *v1.VM,
|
|
) (net.Conn, *ssh.Client, error) {
|
|
var lastErr error
|
|
|
|
for {
|
|
sshClient, err := newSSHClient(tunnel, vm)
|
|
if err == nil {
|
|
return tunnel, sshClient, nil
|
|
}
|
|
|
|
lastErr = err
|
|
_ = tunnel.Close()
|
|
|
|
controller.logger.Warnf("execute session: SSH handshake failed for VM %s on worker %s, retrying: %v",
|
|
vm.Name, vm.Worker, err)
|
|
|
|
if !waitExecuteSSHRetry(ctx) {
|
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
|
return nil, nil, ctxErr
|
|
}
|
|
|
|
return nil, nil, lastErr
|
|
}
|
|
|
|
nextTunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, vm)
|
|
if responderImpl != nil {
|
|
lastErr = errors.New("failed to establish SSH tunnel")
|
|
|
|
if !waitExecuteSSHRetry(ctx) {
|
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
|
return nil, nil, ctxErr
|
|
}
|
|
|
|
return nil, nil, lastErr
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
tunnel = nextTunnel
|
|
}
|
|
}
|
|
|
|
// waitExecuteSSHRetry sleeps for executeSessionSSHRetryDelay. It returns true
// when the full delay elapsed and false when ctx was done first.
func waitExecuteSSHRetry(ctx context.Context) bool {
	timer := time.NewTimer(executeSessionSSHRetryDelay)
	defer timer.Stop()

	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}
|
|
|
|
// isNormalContextCancellation reports whether err is, or wraps, one of the
// errors a context yields when it is canceled or its deadline passes. Callers
// treat these as an expected end of the request rather than a failure.
func isNormalContextCancellation(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}

	return false
}
|
|
|
|
func newSSHClient(conn net.Conn, vm *v1.VM) (*ssh.Client, error) {
|
|
sshUser := vm.Username
|
|
sshPassword := vm.Password
|
|
if sshUser == "" && sshPassword == "" {
|
|
sshUser = "admin"
|
|
sshPassword = "admin"
|
|
}
|
|
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{
|
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
return nil
|
|
},
|
|
User: sshUser,
|
|
Auth: []ssh.AuthMethod{
|
|
ssh.Password(sshPassword),
|
|
},
|
|
Timeout: 10 * time.Second,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ssh.NewClient(sshConn, chans, reqs), nil
|
|
}
|
|
|
|
// startSSHExecution opens a session on client, connects pipes to the remote
// command's stdout, stderr and stdin, and starts the command built from cmd
// and args by buildSSHCommand.
//
// On failure the session is closed. The error describes the failed step and
// wraps the cause with %w, so errors.Is/As still work on it.
func startSSHExecution(client *ssh.Client, cmd string, args []string) (*sshExecution, error) {
	session, err := client.NewSession()
	if err != nil {
		return nil, fmt.Errorf("failed to open SSH session: %w", err)
	}

	// abort releases the partially set-up session and describes the failure.
	abort := func(action string, cause error) (*sshExecution, error) {
		_ = session.Close()

		return nil, fmt.Errorf("failed to %s: %w", action, cause)
	}

	stdoutPipe, err := session.StdoutPipe()
	if err != nil {
		return abort("get SSH stdout pipe", err)
	}

	stderrPipe, err := session.StderrPipe()
	if err != nil {
		return abort("get SSH stderr pipe", err)
	}

	stdinPipe, err := session.StdinPipe()
	if err != nil {
		return abort("get SSH stdin pipe", err)
	}

	if err := session.Start(buildSSHCommand(cmd, args)); err != nil {
		return abort("start SSH command", err)
	}

	return &sshExecution{
		session: session,
		stdout:  stdoutPipe,
		stderr:  stderrPipe,
		stdin:   stdinPipe,
	}, nil
}
|
|
|
|
func startExecuteSessionChannels(
|
|
sessionCtx context.Context,
|
|
decoder *json.Decoder,
|
|
execution *sshExecution,
|
|
) executeSessionChannels {
|
|
channels := executeSessionChannels{
|
|
outputFrameCh: make(chan execstream.Frame, 16),
|
|
outputDoneCh: make(chan struct{}, 2),
|
|
outputErrCh: make(chan error, 2),
|
|
stdinErrCh: make(chan error, 1),
|
|
exitCodeCh: make(chan int32, 1),
|
|
exitErrCh: make(chan error, 1),
|
|
}
|
|
|
|
go forwardSSHOutputFrames(sessionCtx, execution.stdout, execstream.FrameTypeStdout,
|
|
channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh)
|
|
go forwardSSHOutputFrames(sessionCtx, execution.stderr, execstream.FrameTypeStderr,
|
|
channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh)
|
|
go consumeClientInputFrames(decoder, execution.stdin, channels.stdinErrCh)
|
|
go waitForSSHExecutionExit(execution.session, channels.exitCodeCh, channels.exitErrCh)
|
|
|
|
return channels
|
|
}
|
|
|
|
// flushExecuteOutputFrames writes every frame currently buffered in
// outputFrameCh to encoder, stopping at the first write error. It does not
// wait for further frames.
//
// The buffer length is re-sampled on each iteration. That is only a reliable
// "drained" test while the caller is the channel's sole receiver.
func flushExecuteOutputFrames(encoder *json.Encoder, outputFrameCh <-chan execstream.Frame) error {
	for pending := len(outputFrameCh); pending != 0; pending = len(outputFrameCh) {
		frame := <-outputFrameCh

		if err := execstream.WriteFrame(encoder, &frame); err != nil {
			return err
		}
	}

	return nil
}
|
|
|
|
// firstExecuteOutputError drains already-buffered values from outputErrCh
// without blocking. It returns the first non-nil error found, or nil if the
// buffer held only nils or was empty. Values after that first error stay
// buffered.
func firstExecuteOutputError(outputErrCh <-chan error) error {
	var firstErr error

	for firstErr == nil && len(outputErrCh) > 0 {
		firstErr = <-outputErrCh
	}

	return firstErr
}
|
|
|
|
// runExecuteSession relays a started SSH execution over ws until the command
// finishes, the client goes away, or something fails.
//
// The WebSocket is wrapped as a text-message net.Conn that carries
// execstream frames in both directions. Output frames are relayed as they
// arrive. Once the exit status is known and both output forwarders are done,
// finishExecuteSession completes the exchange. Failures are reported to the
// client as error frames where possible. The client is pinged every
// controller.pingInterval.
func (controller *Controller) runExecuteSession(
	ctx context.Context,
	ws *websocket.Conn,
	execution *sshExecution,
) responder.Responder {
	sessionCtx, cancelSession := context.WithCancel(ctx)
	defer cancelSession()

	netConn := websocket.NetConn(sessionCtx, ws, websocket.MessageText)
	defer func() {
		_ = netConn.Close()
	}()

	encoder := execstream.NewEncoder(netConn)
	decoder := execstream.NewDecoder(netConn)

	channels := startExecuteSessionChannels(sessionCtx, decoder, execution)

	pingTicker := time.NewTicker(controller.pingInterval)
	defer pingTicker.Stop()

	var (
		finishedReaders int
		exited          bool
		exitCode        int32
	)

	for {
		// Finish only once the exit status is known AND both the stdout and
		// stderr forwarders have stopped, so all output has been queued
		// before the exit frame goes out.
		if exited && finishedReaders >= 2 {
			return controller.finishExecuteSession(ws, encoder, channels, exitCode)
		}

		select {
		case frame := <-channels.outputFrameCh:
			if err := execstream.WriteFrame(encoder, &frame); err != nil {
				return controller.wsError(ws, websocket.StatusInternalError, "execute session",
					"failed to stream execute output to the client", err)
			}
		case <-channels.outputDoneCh:
			finishedReaders++
		case err := <-channels.outputErrCh:
			if err != nil {
				controller.closeExecuteWithFrameError(ws, encoder,
					fmt.Sprintf("failed while streaming command output: %v", err))

				return responder.Empty()
			}
		case err := <-channels.stdinErrCh:
			switch {
			case err == nil || errors.Is(err, context.Canceled):
				// Nothing more will arrive; a nil channel never becomes ready.
				channels.stdinErrCh = nil
			case errors.Is(err, io.EOF):
				// Presumably the client closed the WebSocket.
				return responder.Empty()
			default:
				controller.closeExecuteWithFrameError(ws, encoder,
					fmt.Sprintf("failed while reading command stdin stream: %v", err))

				return responder.Empty()
			}
		case code := <-channels.exitCodeCh:
			exited = true
			exitCode = code
		case err := <-channels.exitErrCh:
			controller.closeExecuteWithFrameError(ws, encoder,
				fmt.Sprintf("failed while waiting for command completion: %v", err))

			return responder.Empty()
		case <-pingTicker.C:
			pingCtx, cancelPing := context.WithTimeout(ctx, 5*time.Second)

			if err := ws.Ping(pingCtx); err != nil {
				controller.logger.Warnf("execute session: failed to ping the client, "+
					"connection might time out: %v", err)
			}

			cancelPing()
		case <-ctx.Done():
			if isNormalContextCancellation(ctx.Err()) {
				return responder.Empty()
			}

			return responder.Error(ctx.Err())
		}
	}
}

// finishExecuteSession completes a session whose command has exited and
// whose output forwarders are done. If a forwarder reported an error, it
// sends that as an error frame. Otherwise it flushes buffered output, sends
// the exit frame and closes ws normally.
func (controller *Controller) finishExecuteSession(
	ws *websocket.Conn,
	encoder *json.Encoder,
	channels executeSessionChannels,
	exitCode int32,
) responder.Responder {
	if outputErr := firstExecuteOutputError(channels.outputErrCh); outputErr != nil {
		controller.closeExecuteWithFrameError(ws, encoder,
			fmt.Sprintf("failed while streaming command output: %v", outputErr))

		return responder.Empty()
	}

	if err := flushExecuteOutputFrames(encoder, channels.outputFrameCh); err != nil {
		return controller.wsError(ws, websocket.StatusInternalError, "execute session",
			"failed to stream execute output to the client", err)
	}

	exitFrame := execstream.Frame{
		Type: execstream.FrameTypeExit,
		Exit: &execstream.Exit{Code: exitCode},
	}
	if err := execstream.WriteFrame(encoder, &exitFrame); err != nil {
		return controller.wsError(ws, websocket.StatusInternalError, "execute session",
			"failed to send execute exit status to the client", err)
	}

	if err := ws.Close(websocket.StatusNormalClosure,
		fmt.Sprintf("command exited with code %d", exitCode)); err != nil {
		controller.logger.Warnf("execute session: failed to close WebSocket connection: %v", err)
	}

	return responder.Empty()
}
|
|
|
|
// waitForSSHExecutionExit blocks until the remote command finishes and sends
// exactly one value:
//   - its exit status on exitCodeCh: 0 on success, or the status carried by
//     an *ssh.ExitError;
//   - otherwise the error from Wait on exitErrCh.
func waitForSSHExecutionExit(sshSession *ssh.Session, exitCodeCh chan<- int32, exitErrCh chan<- error) {
	err := sshSession.Wait()
	if err == nil {
		exitCodeCh <- 0

		return
	}

	var exitErr *ssh.ExitError
	if !errors.As(err, &exitErr) {
		exitErrCh <- err

		return
	}

	exitCodeCh <- int32(exitErr.ExitStatus())
}
|
|
|
|
// consumeClientInputFrames reads execstream frames sent by the client and
// writes stdin frames into the remote command's stdin.
//
// An empty stdin frame ends the client's input and closes stdin. Reading
// continues after that. Keeping a reader on the WebSocket is what lets the
// websocket library process control frames (pongs for the session's pings,
// the peer's close). It also means a client that disconnects mid-command is
// noticed. Further non-empty stdin data is then rejected. Resize frames are
// ignored.
//
// Exactly one value is sent on errCh before returning, in one of these cases:
//   - the error that ended reading (stdin is closed first if still open);
//   - an error from closing or writing stdin;
//   - a protocol error.
func consumeClientInputFrames(
	decoder *json.Decoder,
	stdin io.WriteCloser,
	errCh chan<- error,
) {
	// stdinClosed is set once stdin must no longer be written to.
	stdinClosed := false

	// commandGone is set when writing to stdin fails with io.EOF. x/crypto/ssh
	// returns that once the SSH channel is closed (presumably because the
	// command already exited). Later input is dropped silently: reporting
	// io.EOF would look like the client hanging up and could end the session
	// before the exit status is delivered.
	commandGone := false

	// closeStdin closes stdin at most once. io.EOF is ignored for the same
	// reason as above: the command side is already closed.
	closeStdin := func() error {
		if stdinClosed {
			return nil
		}
		stdinClosed = true

		if err := stdin.Close(); err != nil && !errors.Is(err, io.EOF) {
			return err
		}

		return nil
	}

	for {
		var frame execstream.Frame

		if err := execstream.ReadFrame(decoder, &frame); err != nil {
			if closeErr := closeStdin(); closeErr != nil {
				errCh <- closeErr

				return
			}

			errCh <- err

			return
		}

		switch frame.Type {
		case execstream.FrameTypeStdin:
			if len(frame.Data) == 0 {
				// End of client input: close stdin but keep reading frames.
				if err := closeStdin(); err != nil {
					errCh <- err

					return
				}

				continue
			}

			switch {
			case commandGone:
				// Drop input the command can no longer receive.
			case stdinClosed:
				errCh <- errors.New("stdin is already closed")

				return
			default:
				if _, err := stdin.Write(frame.Data); err != nil {
					if !errors.Is(err, io.EOF) {
						errCh <- err

						return
					}

					commandGone = true
					stdinClosed = true
				}
			}
		case execstream.FrameTypeResize:
			// No-op for SSH backend without TTY support.
		default:
			errCh <- fmt.Errorf("unsupported frame type %q received from client", frame.Type)

			return
		}
	}
}
|
|
|
|
// forwardSSHOutputFrames reads reader in chunks of up to 4096 bytes and sends
// each chunk as a frame of frameType on outputFrameCh. It stops in three
// cases:
//   - at io.EOF;
//   - when ctx is done while a frame is waiting to be sent;
//   - at any other read error, which is offered on outputErrCh without
//     blocking.
//
// Whatever the reason for stopping, it signals outputDoneCh once.
func forwardSSHOutputFrames(
	ctx context.Context,
	reader io.Reader,
	frameType execstream.FrameType,
	outputFrameCh chan<- execstream.Frame,
	outputDoneCh chan<- struct{},
	outputErrCh chan<- error,
) {
	defer func() {
		outputDoneCh <- struct{}{}
	}()

	chunk := make([]byte, 4096)

	for {
		n, readErr := reader.Read(chunk)

		// Per the io.Reader contract, data may accompany an error, so
		// forward it before looking at readErr.
		if n > 0 {
			// chunk is overwritten by the next Read, so each frame gets a copy.
			data := make([]byte, n)
			copy(data, chunk[:n])

			select {
			case <-ctx.Done():
				return
			case outputFrameCh <- execstream.Frame{Type: frameType, Data: data}:
			}
		}

		switch {
		case readErr == nil:
			// Keep reading.
		case errors.Is(readErr, io.EOF):
			return
		default:
			select {
			case outputErrCh <- readErr:
			default:
			}

			return
		}
	}
}
|
|
|
|
func buildSSHCommand(command string, args []string) string {
|
|
parts := make([]string, 0, 1+len(args))
|
|
parts = append(parts, shellQuoteArg(command))
|
|
for _, arg := range args {
|
|
parts = append(parts, shellQuoteArg(arg))
|
|
}
|
|
|
|
return strings.Join(parts, " ")
|
|
}
|
|
|
|
func shellQuoteArg(arg string) string {
|
|
if arg == "" {
|
|
return "''"
|
|
}
|
|
|
|
return "'" + strings.ReplaceAll(arg, "'", "'\\''") + "'"
|
|
}
|
|
|
|
// maxWebSocketCloseReasonBytes is the longest close reason a WebSocket close
// frame can carry. Control frame payloads are limited to 125 bytes
// (RFC 6455, section 5.5), and two of those hold the status code.
const maxWebSocketCloseReasonBytes = 123

// truncateWebSocketCloseReason shortens reason to at most
// maxWebSocketCloseReasonBytes bytes without splitting a UTF-8 sequence.
func truncateWebSocketCloseReason(reason string) string {
	if len(reason) <= maxWebSocketCloseReasonBytes {
		return reason
	}

	// Ranging over a string visits the byte offset of each rune start; keep
	// the last start that still leaves the result within the limit.
	cut := 0
	for i := range reason {
		if i > maxWebSocketCloseReasonBytes {
			break
		}
		cut = i
	}

	return reason[:cut]
}

// closeExecuteWithFrameError reports a fatal execute-session error to the
// client and closes wsConn with StatusInternalError.
//
// When encoder is non-nil, the full message is first sent as an execstream
// error frame; failures there are only logged. The message also becomes the
// WebSocket close reason, truncated to what a close frame can carry. The
// websocket library does not send an over-long reason at all (it falls back
// to a bare close frame and returns an error). That would lose the message
// entirely when no encoder is available.
func (controller *Controller) closeExecuteWithFrameError(
	wsConn *websocket.Conn,
	encoder *json.Encoder,
	message string,
) {
	if encoder != nil {
		if err := execstream.WriteFrame(encoder, &execstream.Frame{
			Type:  execstream.FrameTypeError,
			Error: message,
		}); err != nil {
			controller.logger.Warnf("execute session: failed to send error frame: %v", err)
		}
	}

	if err := wsConn.Close(websocket.StatusInternalError, truncateWebSocketCloseReason(message)); err != nil {
		controller.logger.Warnf("execute session: failed to close WebSocket connection: %v", err)
	}
}
|