498 lines
12 KiB
Go
498 lines
12 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/avast/retry-go/v5"
|
|
"github.com/cirruslabs/orchard/internal/controller/sshexec"
|
|
"github.com/cirruslabs/orchard/internal/execstream"
|
|
"github.com/cirruslabs/orchard/internal/responder"
|
|
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
"github.com/coder/websocket"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
|
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
|
return responder
|
|
}
|
|
|
|
// Retrieve and parse path and query parameters
|
|
name := ctx.Param("name")
|
|
sessionID := ctx.Query("session")
|
|
if sessionID == "" {
|
|
sessionID = ctx.Query("cmux_session_id")
|
|
}
|
|
|
|
command := ctx.Query("command")
|
|
if sessionID == "" && command == "" {
|
|
return responder.JSON(http.StatusBadRequest,
|
|
NewErrorResponse("\"command\" parameter cannot be empty"))
|
|
}
|
|
|
|
spec, runCommand, err := parseExecSessionSpec(ctx, command)
|
|
if err != nil {
|
|
return responder.JSON(http.StatusBadRequest, NewErrorResponse("%v", err))
|
|
}
|
|
|
|
waitRaw := ctx.DefaultQuery("wait", "10")
|
|
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
|
if err != nil {
|
|
return responder.Code(http.StatusBadRequest)
|
|
}
|
|
|
|
if sessionID != "" {
|
|
return controller.execVMReconnectable(ctx, name, sessionID, spec, runCommand, wait)
|
|
}
|
|
|
|
return controller.execVMLegacy(ctx, name, spec, runCommand, wait)
|
|
}
|
|
|
|
func (controller *Controller) execVMLegacy(
|
|
ctx *gin.Context,
|
|
name string,
|
|
spec execSessionSpec,
|
|
runCommand string,
|
|
wait uint64,
|
|
) responder.Responder {
|
|
// Look-up the VM
|
|
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
|
defer waitContextCancel()
|
|
|
|
vm, responderImpl := controller.waitForVM(waitContext, name)
|
|
if responderImpl != nil {
|
|
return responderImpl
|
|
}
|
|
|
|
session, err := controller.newSSHExecSession(
|
|
ctx,
|
|
waitContext,
|
|
vm,
|
|
execSessionKey{vmName: name},
|
|
spec,
|
|
runCommand,
|
|
nil,
|
|
legacyExecSessionPolicy,
|
|
)
|
|
if err != nil {
|
|
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
|
|
}
|
|
|
|
// Upgrade HTTP request to a WebSocket connection
|
|
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
|
OriginPatterns: []string{"*"},
|
|
})
|
|
if err != nil {
|
|
session.closeIfUnused()
|
|
|
|
return responder.Error(err)
|
|
}
|
|
defer func() {
|
|
// Ensure that we always close the accepted WebSocket connection,
|
|
// otherwise resource leak is possible[1]
|
|
//
|
|
// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
|
|
_ = wsConn.CloseNow()
|
|
}()
|
|
|
|
return controller.serveExecSession(ctx, wsConn, session)
|
|
}
|
|
|
|
func (controller *Controller) execVMReconnectable(
|
|
ctx *gin.Context,
|
|
name string,
|
|
sessionID string,
|
|
spec execSessionSpec,
|
|
runCommand string,
|
|
wait uint64,
|
|
) responder.Responder {
|
|
key := execSessionKey{
|
|
vmName: name,
|
|
sessionID: sessionID,
|
|
}
|
|
|
|
session, ok := controller.execSessions.get(key)
|
|
if ok {
|
|
if !session.specMatches(spec) {
|
|
return responder.JSON(http.StatusConflict,
|
|
NewErrorResponse("exec session %q is already running with different options", sessionID))
|
|
}
|
|
} else {
|
|
if spec.command == "" {
|
|
return responder.JSON(http.StatusNotFound,
|
|
NewErrorResponse("exec session %q does not exist", sessionID))
|
|
}
|
|
|
|
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
|
defer waitContextCancel()
|
|
|
|
vm, responderImpl := controller.waitForVM(waitContext, name)
|
|
if responderImpl != nil {
|
|
return responderImpl
|
|
}
|
|
|
|
var err error
|
|
session, _, err = controller.execSessions.getOrCreate(waitContext, key, func() (*execSession, error) {
|
|
return controller.newSSHExecSession(
|
|
ctx,
|
|
waitContext,
|
|
vm,
|
|
key,
|
|
spec,
|
|
runCommand,
|
|
controller.execSessions,
|
|
reconnectableExecSessionPolicy,
|
|
)
|
|
})
|
|
if err != nil {
|
|
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
|
|
}
|
|
|
|
if !session.specMatches(spec) {
|
|
return responder.JSON(http.StatusConflict,
|
|
NewErrorResponse("exec session %q is already running with different options", sessionID))
|
|
}
|
|
}
|
|
|
|
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
|
OriginPatterns: []string{"*"},
|
|
})
|
|
if err != nil {
|
|
session.closeIfUnused()
|
|
|
|
return responder.Error(err)
|
|
}
|
|
defer func() {
|
|
_ = wsConn.CloseNow()
|
|
}()
|
|
|
|
return controller.serveExecSession(ctx, wsConn, session)
|
|
}
|
|
|
|
func (controller *Controller) newSSHExecSession(
|
|
_ *gin.Context,
|
|
waitContext context.Context,
|
|
vm *v1.VM,
|
|
key execSessionKey,
|
|
spec execSessionSpec,
|
|
runCommand string,
|
|
registry *execSessionRegistry,
|
|
policy execSessionPolicy,
|
|
) (*execSession, error) {
|
|
sessionContext, sessionContextCancel := context.WithCancel(context.Background())
|
|
|
|
type sshExecAttempt struct {
|
|
exec *sshexec.Exec
|
|
}
|
|
|
|
attempt, err := retry.NewWithData[sshExecAttempt](
|
|
retry.Context(waitContext),
|
|
retry.DelayType(retry.FixedDelay),
|
|
retry.Delay(time.Second),
|
|
retry.Attempts(0),
|
|
retry.LastErrorOnly(true),
|
|
).Do(func() (sshExecAttempt, error) {
|
|
exec, err := controller.execSSHClients.newExec(vm.UID, sshexec.Options{
|
|
Interactive: spec.interactive,
|
|
TTY: spec.tty,
|
|
Rows: spec.rows,
|
|
Cols: spec.cols,
|
|
}, func() (sshExecClient, error) {
|
|
portForwardConn, err := controller.portForwardConnection(
|
|
context.Background(),
|
|
waitContext,
|
|
vm.Worker,
|
|
vm.UID,
|
|
22,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client, err := sshexec.NewClient(portForwardConn, vm.SSHUsername(), vm.SSHPassword())
|
|
if err != nil {
|
|
_ = portForwardConn.Close()
|
|
|
|
return nil, fmt.Errorf("failed to establish SSH connection to a VM: %w", err)
|
|
}
|
|
|
|
return client, nil
|
|
})
|
|
if err != nil {
|
|
return sshExecAttempt{}, fmt.Errorf("failed to establish SSH connection to a VM: %w", err)
|
|
}
|
|
|
|
return sshExecAttempt{
|
|
exec: exec,
|
|
}, nil
|
|
})
|
|
if err != nil {
|
|
sessionContextCancel()
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return newExecSessionWithContextAndSpec(
|
|
sessionContext,
|
|
sessionContextCancel,
|
|
key,
|
|
spec,
|
|
runCommand,
|
|
attempt.exec,
|
|
nil,
|
|
registry,
|
|
controller.execSessionRetentionTTL,
|
|
policy,
|
|
), nil
|
|
}
|
|
|
|
func (controller *Controller) serveExecSession(
|
|
ctx *gin.Context,
|
|
wsConn *websocket.Conn,
|
|
session *execSession,
|
|
) responder.Responder {
|
|
subscriber, err := session.attach()
|
|
if err != nil {
|
|
_ = wsConn.Close(websocket.StatusNormalClosure, err.Error())
|
|
|
|
return responder.Empty()
|
|
}
|
|
defer session.detach(subscriber)
|
|
session.start()
|
|
|
|
readFramesErrCh := make(chan error, 1)
|
|
go func() {
|
|
readFramesErrCh <- controller.readExecSessionFrames(ctx, wsConn, session, subscriber)
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case readFramesErr := <-readFramesErrCh:
|
|
if readFramesErr != nil &&
|
|
!errors.Is(readFramesErr, errExecSessionDetached) &&
|
|
!errors.Is(readFramesErr, errExecSessionClosed) {
|
|
controller.logger.Warnf("failed to read and process exec frames from WebSocket: %v",
|
|
readFramesErr)
|
|
}
|
|
|
|
return responder.Empty()
|
|
case outgoingFrame, ok := <-subscriber.frames:
|
|
if !ok {
|
|
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
|
|
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
|
|
}
|
|
|
|
return responder.Empty()
|
|
}
|
|
|
|
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
|
|
controller.logger.Warnf("failed to write exec frame to the client: %v", err)
|
|
|
|
return responder.Empty()
|
|
}
|
|
case <-time.After(controller.pingInterval):
|
|
pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
|
|
if err := wsConn.Ping(pingCtx); err != nil {
|
|
controller.logger.Warnf("exec: failed to ping the client, "+
|
|
"connection might time out: %v", err)
|
|
}
|
|
|
|
pingCtxCancel()
|
|
case <-ctx.Done():
|
|
controller.logger.Warnf("client disconnected prematurely")
|
|
|
|
return responder.Empty()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sentinel errors that end a WebSocket read loop without indicating a failure.
var (
	// errExecSessionDetached signals that the client let go of the session
	// (a "detach" frame or a normal WebSocket closure) while it keeps running.
	errExecSessionDetached = errors.New("exec session detached")

	// errExecSessionClosed signals that the client explicitly closed the session.
	errExecSessionClosed = errors.New("exec session closed")
)
|
|
|
|
func parseExecSessionSpec(ctx *gin.Context, command string) (execSessionSpec, string, error) {
|
|
interactive, err := parseExecInteractive(ctx)
|
|
if err != nil {
|
|
return execSessionSpec{}, "", err
|
|
}
|
|
|
|
tty, err := parseExecBool(ctx, "tty")
|
|
if err != nil {
|
|
return execSessionSpec{}, "", err
|
|
}
|
|
if tty {
|
|
interactive = true
|
|
}
|
|
|
|
rows, err := parseExecUint32(ctx.Query("rows"), "rows")
|
|
if err != nil {
|
|
return execSessionSpec{}, "", err
|
|
}
|
|
cols, err := parseExecUint32(ctx.Query("cols"), "cols")
|
|
if err != nil {
|
|
return execSessionSpec{}, "", err
|
|
}
|
|
if (rows == 0) != (cols == 0) {
|
|
return execSessionSpec{}, "", errors.New("\"rows\" and \"cols\" must be provided together")
|
|
}
|
|
|
|
spec := execSessionSpec{
|
|
command: command,
|
|
interactive: interactive,
|
|
tty: tty,
|
|
rows: rows,
|
|
cols: cols,
|
|
env: ctx.QueryMap("env"),
|
|
workdir: ctx.Query("workdir"),
|
|
}
|
|
|
|
runCommand, err := sshexec.CommandWithOptions(command, sshexec.Options{
|
|
Env: spec.env,
|
|
Workdir: spec.workdir,
|
|
})
|
|
if err != nil {
|
|
return execSessionSpec{}, "", err
|
|
}
|
|
|
|
return spec, runCommand, nil
|
|
}
|
|
|
|
func parseExecInteractive(ctx *gin.Context) (bool, error) {
|
|
interactive, err := parseExecBool(ctx, "interactive")
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
interactiveRaw, interactivePresent := ctx.GetQuery("interactive")
|
|
stdinRaw, stdinPresent := ctx.GetQuery("stdin")
|
|
if !stdinPresent {
|
|
return interactive, nil
|
|
}
|
|
|
|
stdin, err := strconv.ParseBool(stdinRaw)
|
|
if err != nil {
|
|
return false, errors.New("\"stdin\" parameter must be a boolean")
|
|
}
|
|
|
|
if interactivePresent {
|
|
parsedInteractive, _ := strconv.ParseBool(interactiveRaw)
|
|
if stdin != parsedInteractive {
|
|
return false, errors.New("\"interactive\" and \"stdin\" parameters cannot conflict")
|
|
}
|
|
}
|
|
|
|
if !interactivePresent {
|
|
interactive = stdin
|
|
}
|
|
|
|
return interactive, nil
|
|
}
|
|
|
|
func parseExecBool(ctx *gin.Context, name string) (bool, error) {
|
|
raw, present := ctx.GetQuery(name)
|
|
if !present {
|
|
return false, nil
|
|
}
|
|
|
|
value, err := strconv.ParseBool(raw)
|
|
if err != nil {
|
|
return false, fmt.Errorf("%q parameter must be a boolean", name)
|
|
}
|
|
|
|
return value, nil
|
|
}
|
|
|
|
// parseExecUint32 parses raw as a base-10 unsigned 32-bit query parameter value.
//
// An empty raw string means "not provided" and yields 0 without error. name is
// only used in the error message for malformed or out-of-range input.
func parseExecUint32(raw string, name string) (uint32, error) {
	if len(raw) == 0 {
		return 0, nil
	}

	parsed, err := strconv.ParseUint(raw, 10, 32)
	if err == nil {
		return uint32(parsed), nil
	}

	return 0, fmt.Errorf("%q parameter must be an unsigned integer", name)
}
|
|
|
|
func (controller *Controller) readExecSessionFrames(
|
|
ctx context.Context,
|
|
wsConn *websocket.Conn,
|
|
session *execSession,
|
|
subscriber *execSessionSubscriber,
|
|
) error {
|
|
for {
|
|
var frame execstream.Frame
|
|
|
|
messageType, payloadBytes, err := wsConn.Read(ctx)
|
|
if err != nil {
|
|
var closeErr websocket.CloseError
|
|
if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
|
|
return errExecSessionDetached
|
|
}
|
|
|
|
return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
|
|
}
|
|
|
|
if messageType != websocket.MessageText {
|
|
continue
|
|
}
|
|
|
|
if err := json.Unmarshal(payloadBytes, &frame); err != nil {
|
|
return err
|
|
}
|
|
|
|
switch frame.Type {
|
|
case execstream.FrameTypeStdin:
|
|
if err := session.writeStdin(frame.Data); err != nil {
|
|
return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
|
|
}
|
|
case execstream.FrameTypeResize:
|
|
if frame.Terminal == nil {
|
|
return fmt.Errorf("failed to handle %q frame: terminal size is required", frame.Type)
|
|
}
|
|
|
|
if err := session.resize(frame.Terminal.Rows, frame.Terminal.Cols); err != nil {
|
|
return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
|
|
}
|
|
case execstream.FrameTypeHistory:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
session.sendHistory(subscriber, frame.Watermark)
|
|
case execstream.FrameTypeAck:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
session.ack(frame.Watermark)
|
|
case execstream.FrameTypeDetach:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
return errExecSessionDetached
|
|
case execstream.FrameTypeClose:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
session.close()
|
|
|
|
return errExecSessionClosed
|
|
default:
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
}
|
|
}
|