orchard/internal/controller/api_vms_exec.go

354 lines
8.9 KiB
Go

package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"time"
"github.com/avast/retry-go/v5"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
"github.com/cirruslabs/orchard/internal/execstream"
"github.com/cirruslabs/orchard/internal/responder"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
)
// execVM is the HTTP entry point for executing a command inside a VM.
//
// The caller's roles are checked through authorizeAny first. The handler then
// reads the "name" path parameter and the "session" (falling back to
// "cmux_session_id"), "command", "stdin" and "wait" query parameters:
//
//   - a request with neither a session ID nor a command is rejected with
//     400 Bad Request;
//   - "wait" defaults to "10" and must parse as a base-10 unsigned integer
//     that fits into 16 bits, otherwise 400 Bad Request is returned;
//   - "stdin" is only enabled by the exact value "true".
//
// Requests carrying a session ID are served by execVMReconnectable, all
// others by execVMLegacy.
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
	denied := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
		v1.ServiceAccountRoleComputeConnect)
	if denied != nil {
		return denied
	}

	vmName := ctx.Param("name")
	command := ctx.Query("command")

	// "cmux_session_id" is only consulted when "session" is absent or empty
	sessionID := ctx.Query("session")
	if sessionID == "" {
		sessionID = ctx.Query("cmux_session_id")
	}
	reconnectable := sessionID != ""

	// Without a session to re-attach to, there must be something to run
	if !reconnectable && command == "" {
		return responder.JSON(http.StatusBadRequest,
			NewErrorResponse("\"command\" parameter cannot be empty"))
	}

	waitSeconds, err := strconv.ParseUint(ctx.DefaultQuery("wait", "10"), 10, 16)
	if err != nil {
		return responder.Code(http.StatusBadRequest)
	}

	attachStdin := ctx.Query("stdin") == "true"

	if reconnectable {
		return controller.execVMReconnectable(ctx, vmName, sessionID, command, attachStdin, waitSeconds)
	}

	return controller.execVMLegacy(ctx, vmName, command, attachStdin, waitSeconds)
}
// execVMLegacy serves an exec request that carries no session ID.
//
// It waits up to wait seconds for the VM called name, opens an SSH exec
// session for command under legacyExecSessionPolicy (without a session
// registry), upgrades the HTTP request to a WebSocket and then hands both
// over to serveExecSession.
//
// A failure to create the session is reported as 503 Service Unavailable.
// When the WebSocket upgrade fails, the freshly created session is released
// through closeIfUnused before the error is returned.
func (controller *Controller) execVMLegacy(
	ctx *gin.Context,
	name string,
	command string,
	stdin bool,
	wait uint64,
) responder.Responder {
	// The same deadline bounds both the VM look-up and the session setup
	lookupTimeout := time.Duration(wait) * time.Second
	lookupCtx, cancelLookup := context.WithTimeout(ctx, lookupTimeout)
	defer cancelLookup()

	vm, failure := controller.waitForVM(lookupCtx, name)
	if failure != nil {
		return failure
	}

	sessionKey := execSessionKey{vmName: name}

	session, err := controller.newSSHExecSession(ctx, lookupCtx, vm, sessionKey, command, stdin,
		nil, legacyExecSessionPolicy)
	if err != nil {
		return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
	}

	// Switch the HTTP request over to the WebSocket protocol
	acceptOptions := &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	}

	wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, acceptOptions)
	if err != nil {
		session.closeIfUnused()

		return responder.Error(err)
	}

	// An accepted WebSocket connection must always be closed, a resource
	// leak is possible otherwise[1]
	//
	// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
	defer func() {
		_ = wsConn.CloseNow()
	}()

	return controller.serveExecSession(ctx, wsConn, session)
}
// execVMReconnectable serves an exec request that carries a session ID.
//
// The session is looked up in the controller's exec session registry under
// the (VM name, session ID) key. When no such session is registered:
//
//   - an empty command yields 404 Not Found, since there is nothing to
//     attach to and nothing to start;
//   - otherwise the VM is awaited for up to wait seconds and a session is
//     obtained through the registry's getOrCreate using
//     reconnectableExecSessionPolicy; a failure is reported as
//     503 Service Unavailable.
//
// In either case, a session for which commandMatches(command) reports false
// results in 409 Conflict. Finally the request is upgraded to a WebSocket
// and served by serveExecSession; if the upgrade fails, closeIfUnused is
// invoked on the session.
func (controller *Controller) execVMReconnectable(
	ctx *gin.Context,
	name string,
	sessionID string,
	command string,
	stdin bool,
	wait uint64,
) responder.Responder {
	key := execSessionKey{vmName: name, sessionID: sessionID}

	session, found := controller.execSessions.get(key)
	if !found {
		if command == "" {
			return responder.JSON(http.StatusNotFound,
				NewErrorResponse("exec session %q does not exist", sessionID))
		}

		lookupCtx, cancelLookup := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
		// Deferred calls run when the function returns, not when this block ends
		defer cancelLookup()

		vm, failure := controller.waitForVM(lookupCtx, name)
		if failure != nil {
			return failure
		}

		createSession := func() (*execSession, error) {
			return controller.newSSHExecSession(ctx, lookupCtx, vm, key, command, stdin,
				controller.execSessions, reconnectableExecSessionPolicy)
		}

		created, _, createErr := controller.execSessions.getOrCreate(lookupCtx, key, createSession)
		if createErr != nil {
			return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", createErr))
		}

		session = created
	}

	// Applies to sessions found in the registry and to those returned
	// by getOrCreate alike
	if !session.commandMatches(command) {
		return responder.JSON(http.StatusConflict,
			NewErrorResponse("exec session %q is already running a different command", sessionID))
	}

	acceptOptions := &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	}

	wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, acceptOptions)
	if err != nil {
		session.closeIfUnused()

		return responder.Error(err)
	}

	// Make sure that the accepted WebSocket connection never outlives this handler
	defer func() {
		_ = wsConn.CloseNow()
	}()

	return controller.serveExecSession(ctx, wsConn, session)
}
// newSSHExecSession establishes an SSH-backed exec session for a VM.
//
// A port-forwarded connection to port 22 of the VM (identified by vm.Worker
// and vm.UID) is dialed once per second with no limit on the number of
// attempts; retrying is bound to waitContext, and only the last error is
// reported. An SSH exec is then set up on top of that connection using the
// VM's SSH username and password.
//
// The session receives its own cancellable context derived from
// context.Background() rather than from the request. On any failure that
// context is cancelled (and the connection closed, if it was already
// established) before the error is returned. The first parameter is unused.
func (controller *Controller) newSSHExecSession(
	_ *gin.Context,
	waitContext context.Context,
	vm *v1.VM,
	key execSessionKey,
	command string,
	stdin bool,
	registry *execSessionRegistry,
	policy execSessionPolicy,
) (*execSession, error) {
	sessionCtx, cancelSession := context.WithCancel(context.Background())

	dial := func() (net.Conn, error) {
		return controller.portForwardConnection(sessionCtx, waitContext, vm.Worker, vm.UID, 22)
	}

	// Zero attempts means "no limit" here, so retrying is
	// effectively bounded by waitContext only
	retrier := retry.NewWithData[net.Conn](
		retry.Context(waitContext),
		retry.DelayType(retry.FixedDelay),
		retry.Delay(time.Second),
		retry.Attempts(0),
		retry.LastErrorOnly(true),
	)

	conn, err := retrier.Do(dial)
	if err != nil {
		cancelSession()

		return nil, err
	}

	exec, err := sshexec.New(conn, vm.SSHUsername(), vm.SSHPassword(), stdin)
	if err != nil {
		cancelSession()
		_ = conn.Close()

		return nil, fmt.Errorf("failed to establish SSH connection to a VM: %w", err)
	}

	session := newExecSessionWithContext(
		sessionCtx,
		cancelSession,
		key,
		command,
		exec,
		conn,
		registry,
		controller.execSessionExitTTL,
		policy,
	)

	return session, nil
}
// serveExecSession attaches the WebSocket connection wsConn to session and
// pumps frames between the two until one of the sides is done.
//
// If attaching fails, the WebSocket is closed with the attach error as the
// close reason. Otherwise the subscriber is detached when the function
// returns, the session is started, and a goroutine reads the frames coming
// from the client while this function:
//
//   - forwards every frame from subscriber.frames to the client, and closes
//     the WebSocket with the "Command finished" reason once that channel is
//     closed;
//   - pings the client whenever pingInterval elapses without any other
//     activity in this loop, allowing each ping to take up to 5 seconds;
//   - returns when the reader goroutine finishes, a frame cannot be written
//     or ctx is done.
//
// The function always returns responder.Empty(); failures are only logged.
// The WebSocket connection itself is not closed on every path here, callers
// are expected to do so.
func (controller *Controller) serveExecSession(
	ctx *gin.Context,
	wsConn *websocket.Conn,
	session *execSession,
) responder.Responder {
	subscriber, err := session.attach()
	if err != nil {
		_ = wsConn.Close(websocket.StatusNormalClosure, err.Error())

		return responder.Empty()
	}
	defer session.detach(subscriber)

	session.start()

	// Buffered, so that the reader goroutine can always deliver
	// its result and exit, even after this function has returned
	readFramesErrCh := make(chan error, 1)

	go func() {
		readFramesErrCh <- controller.readExecSessionFrames(ctx, wsConn, session, subscriber)
	}()

	// A single timer is re-armed on each pass of the loop below instead
	// of allocating a new one through time.After() for every frame sent
	pingTimer := time.NewTimer(controller.pingInterval)
	defer pingTimer.Stop()

	for {
		select {
		case readFramesErr := <-readFramesErrCh:
			// Detaching from and closing the session are regular ways
			// for the client to finish, not worth a warning
			if readFramesErr != nil &&
				!errors.Is(readFramesErr, errExecSessionDetached) &&
				!errors.Is(readFramesErr, errExecSessionClosed) {
				controller.logger.Warnf("failed to read and process exec frames from WebSocket: %v",
					readFramesErr)
			}

			return responder.Empty()
		case outgoingFrame, ok := <-subscriber.frames:
			if !ok {
				if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
					controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
				}

				return responder.Empty()
			}

			if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
				controller.logger.Warnf("failed to write exec frame to the client: %v", err)

				return responder.Empty()
			}
		case <-pingTimer.C:
			pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)

			if err := wsConn.Ping(pingCtx); err != nil {
				controller.logger.Warnf("exec: failed to ping the client, "+
					"connection might time out: %v", err)
			}

			pingCtxCancel()
		case <-ctx.Done():
			controller.logger.Warnf("client disconnected prematurely")

			return responder.Empty()
		}

		// Restart the inactivity countdown before waiting again. The timer
		// is stopped and its channel drained first, so that a tick which
		// fired while a frame was being handled cannot trigger a stale ping.
		if !pingTimer.Stop() {
			select {
			case <-pingTimer.C:
			default:
			}
		}

		pingTimer.Reset(controller.pingInterval)
	}
}
// errExecSessionDetached is returned by readExecSessionFrames when the client
// stops following the exec session: either the WebSocket was closed normally
// or a detach frame was received.
var errExecSessionDetached = errors.New("exec session detached")

// errExecSessionClosed is returned by readExecSessionFrames after a close
// frame from the client has been processed and the exec session was closed.
var errExecSessionClosed = errors.New("exec session closed")
// readExecSessionFrames reads the frames sent by the client over wsConn and
// applies them to session until the client is done or an error occurs.
//
// Only text messages are processed, messages of any other type are skipped.
// Each text message is decoded as a JSON-encoded execstream.Frame and
// dispatched on its type:
//
//   - stdin frames are written to the session's standard input;
//   - history frames make the session send its history to subscriber,
//     starting from the frame's watermark;
//   - ack frames acknowledge the frame's watermark;
//   - detach frames end the loop with errExecSessionDetached;
//   - close frames close the session and end the loop with
//     errExecSessionClosed.
//
// History, ack, detach and close frames are only accepted when the session's
// policy has replay enabled; otherwise they are rejected the same way as
// frames of an unknown type.
//
// errExecSessionDetached is also returned when the client closes the
// WebSocket with the normal closure status. The function never returns nil:
// all other failures are returned as errors wrapped with context.
func (controller *Controller) readExecSessionFrames(
	ctx context.Context,
	wsConn *websocket.Conn,
	session *execSession,
	subscriber *execSessionSubscriber,
) error {
	for {
		messageType, payloadBytes, err := wsConn.Read(ctx)
		if err != nil {
			// A normal closure initiated by the client is treated as a detach
			var closeErr websocket.CloseError
			if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
				return errExecSessionDetached
			}

			return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
		}

		if messageType != websocket.MessageText {
			continue
		}

		var frame execstream.Frame

		if err := json.Unmarshal(payloadBytes, &frame); err != nil {
			// Wrapped like every other failure in this function, so that
			// the caller's log message tells which step has failed
			return fmt.Errorf("failed to decode frame received from WebSocket: %w", err)
		}

		switch frame.Type {
		case execstream.FrameTypeStdin:
			if err := session.writeStdin(frame.Data); err != nil {
				return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
			}
		case execstream.FrameTypeHistory:
			if !session.policy.replayEnabled {
				return fmt.Errorf("unexpected frame type received: %q", frame.Type)
			}

			session.sendHistory(subscriber, frame.Watermark)
		case execstream.FrameTypeAck:
			if !session.policy.replayEnabled {
				return fmt.Errorf("unexpected frame type received: %q", frame.Type)
			}

			session.ack(frame.Watermark)
		case execstream.FrameTypeDetach:
			if !session.policy.replayEnabled {
				return fmt.Errorf("unexpected frame type received: %q", frame.Type)
			}

			return errExecSessionDetached
		case execstream.FrameTypeClose:
			if !session.policy.replayEnabled {
				return fmt.Errorf("unexpected frame type received: %q", frame.Type)
			}

			session.close()

			return errExecSessionClosed
		default:
			return fmt.Errorf("unexpected frame type received: %q", frame.Type)
		}
	}
}