package controller import ( "context" "encoding/json" "errors" "fmt" "io" "net" "net/http" "strconv" "strings" "time" "github.com/cirruslabs/orchard/internal/execstream" "github.com/cirruslabs/orchard/internal/netconncancel" "github.com/cirruslabs/orchard/internal/responder" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/cirruslabs/orchard/rpc" "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/google/uuid" "golang.org/x/crypto/ssh" ) const ( executeSessionRendezvousTimeout = 15 * time.Second executeSessionSSHRetryDelay = 1 * time.Second ) type sshExecution struct { session *ssh.Session stdout io.Reader stderr io.Reader stdin io.WriteCloser } type executeVMRequest struct { name string command string args []string wait time.Duration } type executeSessionChannels struct { outputFrameCh chan execstream.Frame outputDoneCh chan struct{} outputErrCh chan error stdinErrCh chan error exitCodeCh chan int32 exitErrCh chan error } func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder { if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite, v1.ServiceAccountRoleComputeConnect); responder != nil { return responder } request, responderImpl := parseExecuteVMRequest(ctx) if responderImpl != nil { return responderImpl } waitCtx, waitCancel := context.WithTimeout(ctx, request.wait) defer waitCancel() vm, responderImpl := controller.waitForVM(waitCtx, request.name) if responderImpl != nil { return responderImpl } tunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, vm) if responderImpl != nil { return responderImpl } wsConn, err := acceptExecuteWebSocket(ctx) if err != nil { _ = tunnel.Close() return responder.Error(err) } defer func() { _ = wsConn.CloseNow() }() return controller.executeVMViaSSHTunnel(ctx, tunnel, wsConn, vm, request.command, request.args) } func parseExecuteVMRequest(ctx *gin.Context) (*executeVMRequest, responder.Responder) { command := ctx.Query("command") if command == "" { return nil, responder.Code(http.StatusBadRequest) } waitRaw := ctx.DefaultQuery("wait", "10") wait, err := strconv.ParseUint(waitRaw, 10, 16) if err != nil { return nil, responder.Code(http.StatusBadRequest) } return &executeVMRequest{ name: ctx.Param("name"), command: command, args: ctx.QueryArray("arg"), wait: time.Duration(wait) * time.Second, }, nil } func acceptExecuteWebSocket(ctx *gin.Context) (*websocket.Conn, error) { return websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, }) } func (controller *Controller) establishExecuteSSHTunnel( ctx context.Context, vm *v1.VM, ) (net.Conn, responder.Responder) { tunnelWaitCtx, tunnelWaitCtxCancel := context.WithTimeout(ctx, executeSessionRendezvousTimeout) defer tunnelWaitCtxCancel() rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx) session := uuid.New().String() connCh, cancelRequest := controller.connRendezvous.Request(rendezvousCtx, session) defer cancelRequest() err := controller.workerNotifier.Notify(tunnelWaitCtx, vm.Worker, &rpc.WatchInstruction{ Action: &rpc.WatchInstruction_PortForwardAction{ PortForwardAction: &rpc.WatchInstruction_PortForward{ Session: session, VmUid: vm.UID, Port: 22, }, }, }) if err != nil { rendezvousCtxCancel() controller.logger.Warnf("failed to request VM SSH port-forwarding from the worker %s: %v", vm.Worker, err) return nil, responder.Code(http.StatusServiceUnavailable) } select { case rendezvousResp := <-connCh: if rendezvousResp.ErrorMessage != "" { rendezvousCtxCancel() return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( "failed to establish SSH connection to the VM on the worker: %s", rendezvousResp.ErrorMessage)) } if rendezvousResp.Result == nil { rendezvousCtxCancel() return nil, responder.Code(http.StatusServiceUnavailable) } return netconncancel.New(rendezvousResp.Result, rendezvousCtxCancel), nil case <-tunnelWaitCtx.Done(): rendezvousCtxCancel() if errors.Is(tunnelWaitCtx.Err(), context.DeadlineExceeded) { return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( "timed out waiting for worker %s to establish SSH tunnel", vm.Worker)) } return nil, responder.Error(tunnelWaitCtx.Err()) } } func (controller *Controller) executeVMViaSSHTunnel(ctx context.Context, tunnel net.Conn, ws *websocket.Conn, vm *v1.VM, cmd string, args []string) responder.Responder { defer func() { _ = tunnel.Close() }() tunnel, sshClient, err := controller.establishExecuteSSHClientWithRetry(ctx, tunnel, vm) if err != nil { if isNormalContextCancellation(err) { return responder.Empty() } controller.closeExecuteWithFrameError(ws, nil, fmt.Sprintf("SSH handshake with the VM failed: %v", err)) return responder.Empty() } defer func() { _ = sshClient.Close() }() execution, err := startSSHExecution(sshClient, cmd, args) if err != nil { controller.closeExecuteWithFrameError(ws, nil, err.Error()) return responder.Empty() } defer func() { _ = execution.session.Close() }() return controller.runExecuteSession(ctx, ws, execution) } func (controller *Controller) establishExecuteSSHClientWithRetry( ctx context.Context, tunnel net.Conn, vm *v1.VM, ) (net.Conn, *ssh.Client, error) { var lastErr error for { sshClient, err := newSSHClient(tunnel, vm) if err == nil { return tunnel, sshClient, nil } lastErr = err _ = tunnel.Close() controller.logger.Warnf("execute session: SSH handshake failed for VM %s on worker %s, retrying: %v", vm.Name, vm.Worker, err) if !waitExecuteSSHRetry(ctx) { if ctxErr := ctx.Err(); ctxErr != nil { return nil, nil, ctxErr } return nil, nil, lastErr } nextTunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, vm) if responderImpl != nil { lastErr = errors.New("failed to establish SSH tunnel") if !waitExecuteSSHRetry(ctx) { if ctxErr := ctx.Err(); ctxErr != nil { return nil, nil, ctxErr } return nil, nil, lastErr } continue } tunnel = nextTunnel } } func waitExecuteSSHRetry(ctx context.Context) bool { select { case <-ctx.Done(): return false case <-time.After(executeSessionSSHRetryDelay): return true } } func isNormalContextCancellation(err error) bool { return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) } func newSSHClient(conn net.Conn, vm *v1.VM) (*ssh.Client, error) { sshUser := vm.Username sshPassword := vm.Password if sshUser == "" && sshPassword == "" { sshUser = "admin" sshPassword = "admin" } sshConn, chans, reqs, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, User: sshUser, Auth: []ssh.AuthMethod{ ssh.Password(sshPassword), }, Timeout: 10 * time.Second, }) if err != nil { return nil, err } return ssh.NewClient(sshConn, chans, reqs), nil } func startSSHExecution(client *ssh.Client, cmd string, args []string) (*sshExecution, error) { session, err := client.NewSession() if err != nil { return nil, fmt.Errorf("failed to open SSH session: %v", err) } stdoutPipe, err := session.StdoutPipe() if err != nil { _ = session.Close() return nil, fmt.Errorf("failed to get SSH stdout pipe: %v", err) } stderrPipe, err := session.StderrPipe() if err != nil { _ = session.Close() return nil, fmt.Errorf("failed to get SSH stderr pipe: %v", err) } stdinPipe, err := session.StdinPipe() if err != nil { _ = session.Close() return nil, fmt.Errorf("failed to get SSH stdin pipe: %v", err) } sshCommand := buildSSHCommand(cmd, args) if err := session.Start(sshCommand); err != nil { _ = session.Close() return nil, fmt.Errorf("failed to start SSH command: %v", err) } return &sshExecution{ session: session, stdout: stdoutPipe, stderr: stderrPipe, stdin: stdinPipe, }, nil } func startExecuteSessionChannels( sessionCtx context.Context, decoder *json.Decoder, execution *sshExecution, ) executeSessionChannels { channels := executeSessionChannels{ outputFrameCh: make(chan execstream.Frame, 16), outputDoneCh: make(chan struct{}, 2), outputErrCh: make(chan error, 2), stdinErrCh: make(chan error, 1), exitCodeCh: make(chan int32, 1), exitErrCh: make(chan error, 1), } go forwardSSHOutputFrames(sessionCtx, execution.stdout, execstream.FrameTypeStdout, channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh) go forwardSSHOutputFrames(sessionCtx, execution.stderr, execstream.FrameTypeStderr, channels.outputFrameCh, channels.outputDoneCh, channels.outputErrCh) go consumeClientInputFrames(decoder, execution.stdin, channels.stdinErrCh) go waitForSSHExecutionExit(execution.session, channels.exitCodeCh, channels.exitErrCh) return channels } func flushExecuteOutputFrames(encoder *json.Encoder, outputFrameCh <-chan execstream.Frame) error { for len(outputFrameCh) > 0 { frame := <-outputFrameCh if err := execstream.WriteFrame(encoder, &frame); err != nil { return err } } return nil } func firstExecuteOutputError(outputErrCh <-chan error) error { for len(outputErrCh) > 0 { err := <-outputErrCh if err != nil { return err } } return nil } func (controller *Controller) runExecuteSession( ctx context.Context, ws *websocket.Conn, execution *sshExecution, ) responder.Responder { sessionCtx, sessionCancel := context.WithCancel(ctx) defer sessionCancel() wsNetConn := websocket.NetConn(sessionCtx, ws, websocket.MessageText) defer func() { _ = wsNetConn.Close() }() encoder := execstream.NewEncoder(wsNetConn) decoder := execstream.NewDecoder(wsNetConn) channels := startExecuteSessionChannels(sessionCtx, decoder, execution) pingTicker := time.NewTicker(controller.pingInterval) defer pingTicker.Stop() outputReadersDone := 0 exitSeen := false var exitCode int32 for { if exitSeen && outputReadersDone >= 2 { if outputErr := firstExecuteOutputError(channels.outputErrCh); outputErr != nil { controller.closeExecuteWithFrameError(ws, encoder, fmt.Sprintf("failed while streaming command output: %v", outputErr)) return responder.Empty() } if err := flushExecuteOutputFrames(encoder, channels.outputFrameCh); err != nil { return controller.wsError(ws, websocket.StatusInternalError, "execute session", "failed to stream execute output to the client", err) } if err := execstream.WriteFrame(encoder, &execstream.Frame{ Type: execstream.FrameTypeExit, Exit: &execstream.Exit{Code: exitCode}, }); err != nil { return controller.wsError(ws, websocket.StatusInternalError, "execute session", "failed to send execute exit status to the client", err) } if err := ws.Close(websocket.StatusNormalClosure, fmt.Sprintf("command exited with code %d", exitCode)); err != nil { controller.logger.Warnf("execute session: failed to close WebSocket connection: %v", err) } return responder.Empty() } select { case frame := <-channels.outputFrameCh: if err := execstream.WriteFrame(encoder, &frame); err != nil { return controller.wsError(ws, websocket.StatusInternalError, "execute session", "failed to stream execute output to the client", err) } case <-channels.outputDoneCh: outputReadersDone++ case err := <-channels.outputErrCh: if err == nil { continue } controller.closeExecuteWithFrameError(ws, encoder, fmt.Sprintf("failed while streaming command output: %v", err)) return responder.Empty() case err := <-channels.stdinErrCh: if err == nil || errors.Is(err, context.Canceled) { channels.stdinErrCh = nil continue } if errors.Is(err, io.EOF) { return responder.Empty() } controller.closeExecuteWithFrameError(ws, encoder, fmt.Sprintf("failed while reading command stdin stream: %v", err)) return responder.Empty() case code := <-channels.exitCodeCh: exitSeen = true exitCode = code case err := <-channels.exitErrCh: controller.closeExecuteWithFrameError(ws, encoder, fmt.Sprintf("failed while waiting for command completion: %v", err)) return responder.Empty() case <-pingTicker.C: pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second) if err := ws.Ping(pingCtx); err != nil { controller.logger.Warnf("execute session: failed to ping the client, "+ "connection might time out: %v", err) } pingCtxCancel() case <-ctx.Done(): if isNormalContextCancellation(ctx.Err()) { return responder.Empty() } return responder.Error(ctx.Err()) } } } func waitForSSHExecutionExit(sshSession *ssh.Session, exitCodeCh chan<- int32, exitErrCh chan<- error) { if err := sshSession.Wait(); err != nil { var exitError *ssh.ExitError if errors.As(err, &exitError) { exitCodeCh <- int32(exitError.ExitStatus()) return } exitErrCh <- err return } exitCodeCh <- 0 } func consumeClientInputFrames( decoder *json.Decoder, stdin io.WriteCloser, errCh chan<- error, ) { stdinClosed := false for { var frame execstream.Frame if err := execstream.ReadFrame(decoder, &frame); err != nil { if !stdinClosed { if closeErr := stdin.Close(); closeErr != nil { errCh <- closeErr return } } errCh <- err return } switch frame.Type { case execstream.FrameTypeStdin: if len(frame.Data) == 0 { if !stdinClosed { if err := stdin.Close(); err != nil { errCh <- err return } } errCh <- nil return } if stdinClosed { errCh <- errors.New("stdin is already closed") return } if _, err := stdin.Write(frame.Data); err != nil { errCh <- err return } case execstream.FrameTypeResize: // No-op for SSH backend without TTY support. default: errCh <- fmt.Errorf("unsupported frame type %q received from client", frame.Type) return } } } func forwardSSHOutputFrames( ctx context.Context, reader io.Reader, frameType execstream.FrameType, outputFrameCh chan<- execstream.Frame, outputDoneCh chan<- struct{}, outputErrCh chan<- error, ) { defer func() { outputDoneCh <- struct{}{} }() buffer := make([]byte, 4096) for { n, err := reader.Read(buffer) if n > 0 { frame := execstream.Frame{ Type: frameType, Data: append([]byte(nil), buffer[:n]...), } select { case outputFrameCh <- frame: case <-ctx.Done(): return } } if errors.Is(err, io.EOF) { return } if err != nil { select { case outputErrCh <- err: default: } return } } } func buildSSHCommand(command string, args []string) string { parts := make([]string, 0, 1+len(args)) parts = append(parts, shellQuoteArg(command)) for _, arg := range args { parts = append(parts, shellQuoteArg(arg)) } return strings.Join(parts, " ") } func shellQuoteArg(arg string) string { if arg == "" { return "''" } return "'" + strings.ReplaceAll(arg, "'", "'\\''") + "'" } func (controller *Controller) closeExecuteWithFrameError( wsConn *websocket.Conn, encoder *json.Encoder, message string, ) { if encoder != nil { if err := execstream.WriteFrame(encoder, &execstream.Frame{ Type: execstream.FrameTypeError, Error: message, }); err != nil { controller.logger.Warnf("execute session: failed to send error frame: %v", err) } } if err := wsConn.Close(websocket.StatusInternalError, message); err != nil { controller.logger.Warnf("execute session: failed to close WebSocket connection: %v", err) } }