orchard/internal/controller/api_vms_exec.go

498 lines
12 KiB
Go

package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"time"
"github.com/avast/retry-go/v5"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
"github.com/cirruslabs/orchard/internal/execstream"
"github.com/cirruslabs/orchard/internal/responder"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
)
// execVM is the HTTP entry point for executing a command inside a VM.
//
// The request is first passed through authorizeAny with the compute-write and
// compute-connect service account roles; a non-nil responder from that check is
// returned as-is.
//
// Parameters read here:
//   - path "name": name of the VM.
//   - query "session" (falling back to "cmux_session_id"): exec session
//     identifier. When non-empty the request is handed to execVMReconnectable,
//     otherwise to execVMLegacy.
//   - query "command": command to run; it may only be empty when a session
//     identifier is given.
//   - query "wait": number of seconds handed to the selected handler as its
//     wait budget, parsed as an unsigned 16-bit integer, defaulting to 10.
//
// All other exec options are parsed by parseExecSessionSpec. Every validation
// failure is answered with HTTP 400 and a JSON error body.
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
	if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
		v1.ServiceAccountRoleComputeConnect); responder != nil {
		return responder
	}

	// Retrieve and parse path and query parameters
	name := ctx.Param("name")

	sessionID := ctx.Query("session")
	if sessionID == "" {
		sessionID = ctx.Query("cmux_session_id")
	}

	command := ctx.Query("command")
	if sessionID == "" && command == "" {
		return responder.JSON(http.StatusBadRequest,
			NewErrorResponse("\"command\" parameter cannot be empty"))
	}

	spec, runCommand, err := parseExecSessionSpec(ctx, command)
	if err != nil {
		return responder.JSON(http.StatusBadRequest, NewErrorResponse("%v", err))
	}

	wait, err := strconv.ParseUint(ctx.DefaultQuery("wait", "10"), 10, 16)
	if err != nil {
		// Name the offending parameter, like the other validation failures
		// above do, instead of answering with a bare status code
		return responder.JSON(http.StatusBadRequest,
			NewErrorResponse("\"wait\" parameter must be an unsigned 16-bit integer"))
	}

	if sessionID != "" {
		return controller.execVMReconnectable(ctx, name, sessionID, spec, runCommand, wait)
	}

	return controller.execVMLegacy(ctx, name, spec, runCommand, wait)
}
// execVMLegacy serves an exec request that carries no session identifier.
//
// It waits for the VM named name, creates an SSH exec session for it using
// legacyExecSessionPolicy and no session registry, upgrades the HTTP request
// to a WebSocket and then serves the session over that connection.
//
// wait is the number of seconds allowed for the VM look-up and for
// establishing the SSH exec session. A failure to create the session is
// answered with HTTP 503 and a JSON error body.
func (controller *Controller) execVMLegacy(
	ctx *gin.Context,
	name string,
	spec execSessionSpec,
	runCommand string,
	wait uint64,
) responder.Responder {
	// A single deadline covers both the VM look-up and the SSH session setup
	lookupCtx, cancelLookup := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
	defer cancelLookup()

	// Look-up the VM
	vm, failure := controller.waitForVM(lookupCtx, name)
	if failure != nil {
		return failure
	}

	sessionKey := execSessionKey{vmName: name}

	session, err := controller.newSSHExecSession(ctx, lookupCtx, vm, sessionKey, spec, runCommand,
		nil, legacyExecSessionPolicy)
	if err != nil {
		return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
	}

	// Upgrade HTTP request to a WebSocket connection
	acceptOptions := &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	}

	wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, acceptOptions)
	if err != nil {
		// Nobody will ever attach to the session we have just created
		session.closeIfUnused()

		return responder.Error(err)
	}

	// Ensure that we always close the accepted WebSocket connection,
	// otherwise resource leak is possible[1]
	//
	// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
	defer func() {
		_ = wsConn.CloseNow()
	}()

	return controller.serveExecSession(ctx, wsConn, session)
}
// execVMReconnectable serves an exec request that carries a session identifier.
//
// The session is looked up in controller.execSessions by VM name and session
// identifier. When it is not registered yet, a new one is created through
// getOrCreate using reconnectableExecSessionPolicy, which requires a non-empty
// command (HTTP 404 otherwise) and a VM that becomes available within wait
// seconds (HTTP 503 with a JSON error body when session creation fails).
//
// Whether the session was found or obtained from getOrCreate, the requested
// spec has to match the session's spec, otherwise HTTP 409 is returned.
// Finally the request is upgraded to a WebSocket and the session is served
// over it.
func (controller *Controller) execVMReconnectable(
	ctx *gin.Context,
	name string,
	sessionID string,
	spec execSessionSpec,
	runCommand string,
	wait uint64,
) responder.Responder {
	key := execSessionKey{vmName: name, sessionID: sessionID}

	session, found := controller.execSessions.get(key)
	if !found {
		// Without a command there is nothing we could start on behalf of the client
		if spec.command == "" {
			return responder.JSON(http.StatusNotFound,
				NewErrorResponse("exec session %q does not exist", sessionID))
		}

		waitContext, cancelWait := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
		defer cancelWait()

		vm, failure := controller.waitForVM(waitContext, name)
		if failure != nil {
			return failure
		}

		createSession := func() (*execSession, error) {
			return controller.newSSHExecSession(ctx, waitContext, vm, key, spec, runCommand,
				controller.execSessions, reconnectableExecSessionPolicy)
		}

		obtained, _, err := controller.execSessions.getOrCreate(waitContext, key, createSession)
		if err != nil {
			return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
		}

		session = obtained
	}

	// Applies to pre-existing sessions and to whatever getOrCreate handed back alike
	if !session.specMatches(spec) {
		return responder.JSON(http.StatusConflict,
			NewErrorResponse("exec session %q is already running with different options", sessionID))
	}

	acceptOptions := &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	}

	wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, acceptOptions)
	if err != nil {
		session.closeIfUnused()

		return responder.Error(err)
	}

	// Make sure the accepted WebSocket connection is released on every return path
	defer func() {
		_ = wsConn.CloseNow()
	}()

	return controller.serveExecSession(ctx, wsConn, session)
}
// newSSHExecSession creates an exec session backed by an SSH exec obtained
// from controller.execSSHClients for the given VM.
//
// Obtaining the SSH exec is retried with a fixed one-second delay for as long
// as waitContext is alive (retry.Attempts(0) removes the attempt limit), and
// only the last error is reported. When an SSH client has to be dialed, it is
// established over a port-forwarding connection to port 22 of the VM, using
// the VM's SSH username and password.
//
// The session gets its own cancellable context derived from
// context.Background() rather than from the request, and that context is
// cancelled right away when the SSH exec cannot be obtained.
//
// The first parameter is unused. registry, policy, key, spec and runCommand
// are passed through to newExecSessionWithContextAndSpec together with
// controller.execSessionRetentionTTL.
func (controller *Controller) newSSHExecSession(
	_ *gin.Context,
	waitContext context.Context,
	vm *v1.VM,
	key execSessionKey,
	spec execSessionSpec,
	runCommand string,
	registry *execSessionRegistry,
	policy execSessionPolicy,
) (*execSession, error) {
	sessionContext, sessionContextCancel := context.WithCancel(context.Background())

	type sshExecAttempt struct {
		exec *sshexec.Exec
	}

	attempt, err := retry.NewWithData[sshExecAttempt](
		retry.Context(waitContext),
		retry.DelayType(retry.FixedDelay),
		retry.Delay(time.Second),
		retry.Attempts(0),
		retry.LastErrorOnly(true),
	).Do(func() (sshExecAttempt, error) {
		exec, err := controller.execSSHClients.newExec(vm.UID, sshexec.Options{
			Interactive: spec.interactive,
			TTY:         spec.tty,
			Rows:        spec.rows,
			Cols:        spec.cols,
		}, func() (sshExecClient, error) {
			portForwardConn, err := controller.portForwardConnection(
				context.Background(),
				waitContext,
				vm.Worker,
				vm.UID,
				22,
			)
			if err != nil {
				return nil, err
			}

			client, err := sshexec.NewClient(portForwardConn, vm.SSHUsername(), vm.SSHPassword())
			if err != nil {
				// The connection is of no use without an SSH client on top of it
				_ = portForwardConn.Close()

				// Returned unwrapped on purpose: the error coming out of newExec
				// is wrapped below with the very same context, and wrapping here
				// as well would repeat that prefix in the final message
				return nil, err
			}

			return client, nil
		})
		if err != nil {
			return sshExecAttempt{}, fmt.Errorf("failed to establish SSH connection to a VM: %w", err)
		}

		return sshExecAttempt{
			exec: exec,
		}, nil
	})
	if err != nil {
		// No session will be created, so release its context
		sessionContextCancel()

		return nil, err
	}

	return newExecSessionWithContextAndSpec(
		sessionContext,
		sessionContextCancel,
		key,
		spec,
		runCommand,
		attempt.exec,
		nil,
		registry,
		controller.execSessionRetentionTTL,
		policy,
	), nil
}
// serveExecSession attaches the WebSocket connection to the exec session and
// relays frames in both directions until one of the sides is done.
//
// Incoming frames are processed by readExecSessionFrames in a separate
// goroutine. This function writes the subscriber's outgoing frames to the
// client and pings the client whenever controller.pingInterval passes without
// any other loop activity.
//
// The loop ends when:
//   - the reader goroutine returns (errors other than errExecSessionDetached
//     and errExecSessionClosed are logged),
//   - the subscriber's frame channel is closed, in which case the WebSocket is
//     closed normally with the "Command finished" reason,
//   - a frame cannot be written to the client, or
//   - ctx is done.
//
// The subscriber is detached from the session on return. The returned
// responder is always empty because the HTTP response has already been taken
// over by the WebSocket.
func (controller *Controller) serveExecSession(
	ctx *gin.Context,
	wsConn *websocket.Conn,
	session *execSession,
) responder.Responder {
	subscriber, err := session.attach()
	if err != nil {
		_ = wsConn.Close(websocket.StatusNormalClosure, err.Error())

		return responder.Empty()
	}
	defer session.detach(subscriber)

	session.start()

	// Buffered so that the reader goroutine can always deliver its result
	// and exit, even after this function has already returned
	readFramesErrCh := make(chan error, 1)

	go func() {
		readFramesErrCh <- controller.readExecSessionFrames(ctx, wsConn, session, subscriber)
	}()

	// A single reusable timer instead of time.After() in the loop below,
	// which would allocate a new timer for every frame sent to the client
	pingTimer := time.NewTimer(controller.pingInterval)
	defer pingTimer.Stop()

	for {
		select {
		case readFramesErr := <-readFramesErrCh:
			if readFramesErr != nil &&
				!errors.Is(readFramesErr, errExecSessionDetached) &&
				!errors.Is(readFramesErr, errExecSessionClosed) {
				controller.logger.Warnf("failed to read and process exec frames from WebSocket: %v",
					readFramesErr)
			}

			return responder.Empty()
		case outgoingFrame, ok := <-subscriber.frames:
			if !ok {
				if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
					controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
				}

				return responder.Empty()
			}

			if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
				controller.logger.Warnf("failed to write exec frame to the client: %v", err)

				return responder.Empty()
			}
		case <-pingTimer.C:
			pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)

			if err := wsConn.Ping(pingCtx); err != nil {
				controller.logger.Warnf("exec: failed to ping the client, "+
					"connection might time out: %v", err)
			}

			pingCtxCancel()
		case <-ctx.Done():
			controller.logger.Warnf("client disconnected prematurely")

			return responder.Empty()
		}

		// Restart the idle period after every frame or ping. The timer is
		// stopped and its channel is drained first, so that a tick that fired
		// in the meantime is not mistaken for the next idle period.
		if !pingTimer.Stop() {
			select {
			case <-pingTimer.C:
			default:
			}
		}

		pingTimer.Reset(controller.pingInterval)
	}
}
// Sentinel errors returned by readExecSessionFrames to report an orderly end
// of frame processing. serveExecSession does not log either of them.
var (
	// errExecSessionClosed is returned after a close frame was received
	// and the exec session was closed in response to it.
	errExecSessionClosed = errors.New("exec session closed")

	// errExecSessionDetached is returned when the client sent a detach frame
	// or closed the WebSocket connection with the normal closure status.
	errExecSessionDetached = errors.New("exec session detached")
)
// parseExecSessionSpec builds the exec session spec for command out of the
// request's query parameters and renders the command line to run.
//
// Query parameters: "interactive"/"stdin" (see parseExecInteractive), "tty",
// "rows", "cols", the "env" map and "workdir". Requesting a TTY implies an
// interactive session. "rows" and "cols" have to be either both zero/absent
// or both non-zero.
//
// It returns the spec together with the command produced by
// sshexec.CommandWithOptions from command, the environment and the working
// directory. On error the spec and the command are zero values.
func parseExecSessionSpec(ctx *gin.Context, command string) (execSessionSpec, string, error) {
	// Every failure path returns the same zero values next to the error
	fail := func(err error) (execSessionSpec, string, error) {
		return execSessionSpec{}, "", err
	}

	interactive, err := parseExecInteractive(ctx)
	if err != nil {
		return fail(err)
	}

	tty, err := parseExecBool(ctx, "tty")
	if err != nil {
		return fail(err)
	}

	rows, err := parseExecUint32(ctx.Query("rows"), "rows")
	if err != nil {
		return fail(err)
	}

	cols, err := parseExecUint32(ctx.Query("cols"), "cols")
	if err != nil {
		return fail(err)
	}

	hasRows, hasCols := rows != 0, cols != 0
	if hasRows != hasCols {
		return fail(errors.New("\"rows\" and \"cols\" must be provided together"))
	}

	env := ctx.QueryMap("env")
	workdir := ctx.Query("workdir")

	runCommand, err := sshexec.CommandWithOptions(command, sshexec.Options{
		Env:     env,
		Workdir: workdir,
	})
	if err != nil {
		return fail(err)
	}

	spec := execSessionSpec{
		command: command,
		// A TTY only makes sense for an interactive session
		interactive: interactive || tty,
		tty:         tty,
		rows:        rows,
		cols:        cols,
		env:         env,
		workdir:     workdir,
	}

	return spec, runCommand, nil
}
// parseExecInteractive determines whether the exec session should be
// interactive, based on the "interactive" and "stdin" query parameters.
//
// Both parameters are booleans that express the same setting: either may be
// given alone, and when both are present they have to agree. When neither is
// present the result is false.
func parseExecInteractive(ctx *gin.Context) (bool, error) {
	interactive, err := parseExecBool(ctx, "interactive")
	if err != nil {
		return false, err
	}

	stdinRaw, hasStdin := ctx.GetQuery("stdin")
	if !hasStdin {
		return interactive, nil
	}

	stdin, err := strconv.ParseBool(stdinRaw)
	if err != nil {
		return false, errors.New("\"stdin\" parameter must be a boolean")
	}

	_, hasInteractive := ctx.GetQuery("interactive")

	switch {
	case !hasInteractive:
		// Only "stdin" was provided, so it decides
		return stdin, nil
	case stdin != interactive:
		return false, errors.New("\"interactive\" and \"stdin\" parameters cannot conflict")
	default:
		return interactive, nil
	}
}
// parseExecBool reads the query parameter name as a boolean in the syntax
// accepted by strconv.ParseBool.
//
// An absent parameter yields false. A parameter that is present but cannot be
// parsed, including an empty one, yields an error mentioning the parameter.
func parseExecBool(ctx *gin.Context, name string) (bool, error) {
	if raw, present := ctx.GetQuery(name); present {
		parsed, parseErr := strconv.ParseBool(raw)
		if parseErr != nil {
			return false, fmt.Errorf("%q parameter must be a boolean", name)
		}

		return parsed, nil
	}

	return false, nil
}
// parseExecUint32 parses raw as a base-10 unsigned integer that fits into
// 32 bits.
//
// An empty raw yields zero without an error. Anything else that cannot be
// parsed yields an error mentioning the parameter called name.
func parseExecUint32(raw string, name string) (uint32, error) {
	if len(raw) == 0 {
		return 0, nil
	}

	const base, bitSize = 10, 32

	parsed, parseErr := strconv.ParseUint(raw, base, bitSize)
	if parseErr != nil {
		return 0, fmt.Errorf("%q parameter must be an unsigned integer", name)
	}

	return uint32(parsed), nil
}
// readExecSessionFrames reads frames sent by the client over the WebSocket
// and applies them to the exec session until an error occurs or the client
// asks to stop.
//
// Only text messages are processed, each one being a JSON-encoded
// execstream.Frame; messages of other types are skipped. Stdin and resize
// frames are forwarded to the session. History, ack, detach and close frames
// are only accepted when the session's policy has replay enabled, otherwise
// they are rejected like any unknown frame type.
//
// It returns errExecSessionDetached when the client closed the connection
// with the normal closure status or sent a detach frame, and
// errExecSessionClosed after closing the session in response to a close
// frame. Any other return value is a read, decoding or frame handling error.
func (controller *Controller) readExecSessionFrames(
	ctx context.Context,
	wsConn *websocket.Conn,
	session *execSession,
	subscriber *execSessionSubscriber,
) error {
	for {
		messageType, payloadBytes, err := wsConn.Read(ctx)
		if err != nil {
			var closeErr websocket.CloseError

			// A normal closure means that the client went away on purpose
			if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
				return errExecSessionDetached
			}

			return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
		}

		if messageType != websocket.MessageText {
			continue
		}

		var frame execstream.Frame

		if err := json.Unmarshal(payloadBytes, &frame); err != nil {
			return err
		}

		// Frames that only make sense for sessions with replay enabled
		switch frame.Type {
		case execstream.FrameTypeHistory, execstream.FrameTypeAck,
			execstream.FrameTypeDetach, execstream.FrameTypeClose:
			if !session.policy.replayEnabled {
				return fmt.Errorf("unexpected frame type received: %q", frame.Type)
			}
		}

		switch frame.Type {
		case execstream.FrameTypeStdin:
			if err := session.writeStdin(frame.Data); err != nil {
				return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
			}
		case execstream.FrameTypeResize:
			if frame.Terminal == nil {
				return fmt.Errorf("failed to handle %q frame: terminal size is required", frame.Type)
			}

			if err := session.resize(frame.Terminal.Rows, frame.Terminal.Cols); err != nil {
				return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
			}
		case execstream.FrameTypeHistory:
			session.sendHistory(subscriber, frame.Watermark)
		case execstream.FrameTypeAck:
			session.ack(frame.Watermark)
		case execstream.FrameTypeDetach:
			return errExecSessionDetached
		case execstream.FrameTypeClose:
			session.close()

			return errExecSessionClosed
		default:
			return fmt.Errorf("unexpected frame type received: %q", frame.Type)
		}
	}
}