package controller
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/avast/retry-go/v5"
|
|
"github.com/cirruslabs/orchard/internal/controller/sshexec"
|
|
"github.com/cirruslabs/orchard/internal/execstream"
|
|
"github.com/cirruslabs/orchard/internal/responder"
|
|
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
"github.com/coder/websocket"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
|
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
|
return responder
|
|
}
|
|
|
|
// Retrieve and parse path and query parameters
|
|
name := ctx.Param("name")
|
|
sessionID := ctx.Query("session")
|
|
if sessionID == "" {
|
|
sessionID = ctx.Query("cmux_session_id")
|
|
}
|
|
|
|
command := ctx.Query("command")
|
|
if sessionID == "" && command == "" {
|
|
return responder.JSON(http.StatusBadRequest,
|
|
NewErrorResponse("\"command\" parameter cannot be empty"))
|
|
}
|
|
|
|
stdin := ctx.Query("stdin") == "true"
|
|
|
|
waitRaw := ctx.DefaultQuery("wait", "10")
|
|
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
|
if err != nil {
|
|
return responder.Code(http.StatusBadRequest)
|
|
}
|
|
|
|
if sessionID != "" {
|
|
return controller.execVMReconnectable(ctx, name, sessionID, command, stdin, wait)
|
|
}
|
|
|
|
return controller.execVMLegacy(ctx, name, command, stdin, wait)
|
|
}
|
|
|
|
func (controller *Controller) execVMLegacy(
|
|
ctx *gin.Context,
|
|
name string,
|
|
command string,
|
|
stdin bool,
|
|
wait uint64,
|
|
) responder.Responder {
|
|
// Look-up the VM
|
|
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
|
defer waitContextCancel()
|
|
|
|
vm, responderImpl := controller.waitForVM(waitContext, name)
|
|
if responderImpl != nil {
|
|
return responderImpl
|
|
}
|
|
|
|
session, err := controller.newSSHExecSession(
|
|
ctx,
|
|
waitContext,
|
|
vm,
|
|
execSessionKey{vmName: name},
|
|
command,
|
|
stdin,
|
|
nil,
|
|
legacyExecSessionPolicy,
|
|
)
|
|
if err != nil {
|
|
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
|
|
}
|
|
|
|
// Upgrade HTTP request to a WebSocket connection
|
|
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
|
OriginPatterns: []string{"*"},
|
|
})
|
|
if err != nil {
|
|
session.closeIfUnused()
|
|
|
|
return responder.Error(err)
|
|
}
|
|
defer func() {
|
|
// Ensure that we always close the accepted WebSocket connection,
|
|
// otherwise resource leak is possible[1]
|
|
//
|
|
// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
|
|
_ = wsConn.CloseNow()
|
|
}()
|
|
|
|
return controller.serveExecSession(ctx, wsConn, session)
|
|
}
|
|
|
|
// execVMReconnectable serves an exec session identified by (VM name, session ID)
// that survives client disconnects: the session is stored in the controller's
// exec session registry, and a client presenting the same session ID re-attaches
// to the existing session instead of starting a new command.
func (controller *Controller) execVMReconnectable(
	ctx *gin.Context,
	name string,
	sessionID string,
	command string,
	stdin bool,
	wait uint64,
) responder.Responder {
	key := execSessionKey{
		vmName:    name,
		sessionID: sessionID,
	}

	session, ok := controller.execSessions.get(key)
	if ok {
		// Re-attach path: refuse if the client asks for a different command
		// than the one the existing session is running
		if !session.commandMatches(command) {
			return responder.JSON(http.StatusConflict,
				NewErrorResponse("exec session %q is already running a different command", sessionID))
		}
	} else {
		// An empty command is only valid for re-attaching to an existing
		// session, which we've just established doesn't exist
		if command == "" {
			return responder.JSON(http.StatusNotFound,
				NewErrorResponse("exec session %q does not exist", sessionID))
		}

		// Give the VM up to "wait" seconds to show up
		waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
		defer waitContextCancel()

		vm, responderImpl := controller.waitForVM(waitContext, name)
		if responderImpl != nil {
			return responderImpl
		}

		// getOrCreate de-duplicates concurrent requests for the same key:
		// only one caller runs the factory, the others receive its result
		var err error
		session, _, err = controller.execSessions.getOrCreate(waitContext, key, func() (*execSession, error) {
			return controller.newSSHExecSession(
				ctx,
				waitContext,
				vm,
				key,
				command,
				stdin,
				controller.execSessions,
				reconnectableExecSessionPolicy,
			)
		})
		if err != nil {
			return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
		}

		// Re-check the command: getOrCreate may have returned a session
		// created concurrently by another request with a different command
		if !session.commandMatches(command) {
			return responder.JSON(http.StatusConflict,
				NewErrorResponse("exec session %q is already running a different command", sessionID))
		}
	}

	// Upgrade HTTP request to a WebSocket connection
	wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	})
	if err != nil {
		session.closeIfUnused()

		return responder.Error(err)
	}
	// Always close the accepted WebSocket connection to avoid a resource
	// leak (see https://github.com/coder/websocket/issues/445)
	defer func() {
		_ = wsConn.CloseNow()
	}()

	return controller.serveExecSession(ctx, wsConn, session)
}
|
|
|
|
// newSSHExecSession establishes an SSH connection to the given VM (via a
// port-forwarded connection through the VM's worker) and wraps it into an
// exec session governed by the given policy.
//
// waitContext bounds how long we retry to obtain the port-forwarded
// connection; the session itself runs on an independent background context
// so that it can outlive the originating HTTP request (needed for
// reconnectable sessions). The caller owns the returned session and is
// responsible for eventually closing it.
func (controller *Controller) newSSHExecSession(
	_ *gin.Context,
	waitContext context.Context,
	vm *v1.VM,
	key execSessionKey,
	command string,
	stdin bool,
	registry *execSessionRegistry,
	policy execSessionPolicy,
) (*execSession, error) {
	// Deliberately detached from the request context: the session's
	// lifetime is managed by sessionContextCancel, not by the request
	sessionContext, sessionContextCancel := context.WithCancel(context.Background())

	// Retry once a second, indefinitely (Attempts(0)), until either the
	// port-forwarded connection is established or waitContext expires
	portForwardConn, err := retry.NewWithData[net.Conn](
		retry.Context(waitContext),
		retry.DelayType(retry.FixedDelay),
		retry.Delay(time.Second),
		retry.Attempts(0),
		retry.LastErrorOnly(true),
	).Do(func() (net.Conn, error) {
		// 22 is the SSH port inside the VM
		return controller.portForwardConnection(sessionContext, waitContext, vm.Worker, vm.UID, 22)
	})
	if err != nil {
		sessionContextCancel()

		return nil, err
	}

	exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
	if err != nil {
		// Undo both the session context and the port-forwarded connection
		// acquired above before bailing out
		sessionContextCancel()
		_ = portForwardConn.Close()

		return nil, fmt.Errorf("failed to establish SSH connection to a VM: %w", err)
	}

	return newExecSessionWithContext(
		sessionContext,
		sessionContextCancel,
		key,
		command,
		exec,
		portForwardConn,
		registry,
		controller.execSessionExitTTL,
		policy,
	), nil
}
|
|
|
|
func (controller *Controller) serveExecSession(
|
|
ctx *gin.Context,
|
|
wsConn *websocket.Conn,
|
|
session *execSession,
|
|
) responder.Responder {
|
|
subscriber, err := session.attach()
|
|
if err != nil {
|
|
_ = wsConn.Close(websocket.StatusNormalClosure, err.Error())
|
|
|
|
return responder.Empty()
|
|
}
|
|
defer session.detach(subscriber)
|
|
session.start()
|
|
|
|
readFramesErrCh := make(chan error, 1)
|
|
go func() {
|
|
readFramesErrCh <- controller.readExecSessionFrames(ctx, wsConn, session, subscriber)
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case readFramesErr := <-readFramesErrCh:
|
|
if readFramesErr != nil &&
|
|
!errors.Is(readFramesErr, errExecSessionDetached) &&
|
|
!errors.Is(readFramesErr, errExecSessionClosed) {
|
|
controller.logger.Warnf("failed to read and process exec frames from WebSocket: %v",
|
|
readFramesErr)
|
|
}
|
|
|
|
return responder.Empty()
|
|
case outgoingFrame, ok := <-subscriber.frames:
|
|
if !ok {
|
|
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
|
|
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
|
|
}
|
|
|
|
return responder.Empty()
|
|
}
|
|
|
|
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
|
|
controller.logger.Warnf("failed to write exec frame to the client: %v", err)
|
|
|
|
return responder.Empty()
|
|
}
|
|
case <-time.After(controller.pingInterval):
|
|
pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
|
|
if err := wsConn.Ping(pingCtx); err != nil {
|
|
controller.logger.Warnf("exec: failed to ping the client, "+
|
|
"connection might time out: %v", err)
|
|
}
|
|
|
|
pingCtxCancel()
|
|
case <-ctx.Done():
|
|
controller.logger.Warnf("client disconnected prematurely")
|
|
|
|
return responder.Empty()
|
|
}
|
|
}
|
|
}
|
|
|
|
var (
	// errExecSessionDetached signals that the client disconnected from the
	// exec session on purpose (an explicit detach frame or a clean
	// WebSocket closure) without terminating the session itself.
	errExecSessionDetached = errors.New("exec session detached")
	// errExecSessionClosed signals that the client explicitly requested
	// the exec session to be closed.
	errExecSessionClosed = errors.New("exec session closed")
)
|
|
|
|
func (controller *Controller) readExecSessionFrames(
|
|
ctx context.Context,
|
|
wsConn *websocket.Conn,
|
|
session *execSession,
|
|
subscriber *execSessionSubscriber,
|
|
) error {
|
|
for {
|
|
var frame execstream.Frame
|
|
|
|
messageType, payloadBytes, err := wsConn.Read(ctx)
|
|
if err != nil {
|
|
var closeErr websocket.CloseError
|
|
if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
|
|
return errExecSessionDetached
|
|
}
|
|
|
|
return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
|
|
}
|
|
|
|
if messageType != websocket.MessageText {
|
|
continue
|
|
}
|
|
|
|
if err := json.Unmarshal(payloadBytes, &frame); err != nil {
|
|
return err
|
|
}
|
|
|
|
switch frame.Type {
|
|
case execstream.FrameTypeStdin:
|
|
if err := session.writeStdin(frame.Data); err != nil {
|
|
return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
|
|
}
|
|
case execstream.FrameTypeHistory:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
session.sendHistory(subscriber, frame.Watermark)
|
|
case execstream.FrameTypeAck:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
session.ack(frame.Watermark)
|
|
case execstream.FrameTypeDetach:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
return errExecSessionDetached
|
|
case execstream.FrameTypeClose:
|
|
if !session.policy.replayEnabled {
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
|
|
session.close()
|
|
|
|
return errExecSessionClosed
|
|
default:
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
}
|
|
}
|