orchard/internal/controller/api_vms_exec.go

207 lines
5.7 KiB
Go

package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"time"
"github.com/avast/retry-go/v5"
"github.com/cirruslabs/orchard/internal/controller/sshexec"
"github.com/cirruslabs/orchard/internal/execstream"
"github.com/cirruslabs/orchard/internal/responder"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/coder/websocket"
"github.com/gin-gonic/gin"
)
// execVM runs a command inside a VM and streams its I/O over a WebSocket.
//
// It waits for the VM named by the "name" path parameter, opens a
// port-forwarding connection to the VM's port 22, establishes an SSH session
// over it, upgrades the HTTP request to a WebSocket and then runs the command
// while relaying execstream frames between the client and the SSH session.
//
// Query parameters:
//   - command: the command to run; required and must be non-empty
//   - stdin:   "true" to request a stdin-enabled exec session
//   - wait:    seconds to wait for the VM and its SSH port (default 10,
//     parsed as a 16-bit unsigned integer)
//
// Errors that occur before the WebSocket upgrade are reported as HTTP
// responses. After the upgrade, a command failure is sent to the client as an
// error frame, and the handler always returns an empty response.
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
	if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
		v1.ServiceAccountRoleComputeConnect); responder != nil {
		return responder
	}

	// Retrieve and parse path and query parameters
	name := ctx.Param("name")

	command := ctx.Query("command")
	if command == "" {
		return responder.JSON(http.StatusBadRequest,
			NewErrorResponse("\"command\" parameter cannot be empty"))
	}

	stdin := ctx.Query("stdin") == "true"

	waitRaw := ctx.DefaultQuery("wait", "10")
	wait, err := strconv.ParseUint(waitRaw, 10, 16)
	if err != nil {
		// Explain the rejection, consistently with the "command" check above
		return responder.JSON(http.StatusBadRequest,
			NewErrorResponse("\"wait\" parameter must be an integer number of "+
				"seconds between 0 and 65535: %v", err))
	}

	// Bound both the VM look-up and the SSH port-forwarding retries by "wait"
	waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
	defer waitContextCancel()

	// Look-up the VM
	vm, responderImpl := controller.waitForVM(waitContext, name)
	if responderImpl != nil {
		return responderImpl
	}

	// Establish a port-forwarding connection to a VM's SSH port, retrying
	// every second with no attempt limit until waitContext expires
	portForwardConn, err := retry.NewWithData[net.Conn](
		retry.Context(waitContext),
		retry.DelayType(retry.FixedDelay),
		retry.Delay(time.Second),
		retry.Attempts(0),
		retry.LastErrorOnly(true),
	).Do(func() (net.Conn, error) {
		return controller.portForwardConnection(ctx, waitContext, vm.Worker, vm.UID, 22)
	})
	if err != nil {
		return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
	}
	defer portForwardConn.Close()

	// Establish an SSH connection to a VM
	exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
	if err != nil {
		return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err))
	}
	defer exec.Close()

	// Upgrade HTTP request to a WebSocket connection
	wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	})
	if err != nil {
		return responder.Error(err)
	}
	defer func() {
		// Ensure that we always close the accepted WebSocket connection,
		// otherwise resource leak is possible[1]
		//
		// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
		_ = wsConn.CloseNow()
	}()

	// Read WebSocket frames from the client (stdin) in the background.
	// Buffered so the goroutine can always deliver its result and exit.
	readFramesErrCh := make(chan error, 1)
	go func() {
		readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin())
	}()

	// Run the command; its output arrives as frames on outgoingFrames.
	//
	// NOTE(review): on the early-return paths below nothing drains
	// outgoingFrames any more. This relies on exec.Run returning once the
	// deferred exec.Close() runs — confirm in sshexec.
	sshErrCh := make(chan error, 1)
	outgoingFrames := make(chan *execstream.Frame)
	go func() {
		sshErrCh <- exec.Run(ctx, command, outgoingFrames)
	}()

	for {
		select {
		case readFramesErr := <-readFramesErrCh:
			// A nil error means the client closed the WebSocket with a normal
			// closure status, which is not worth a warning
			if readFramesErr != nil {
				controller.logger.Warnf("exec: failed to read and process frames from WebSocket: %v", readFramesErr)
			}

			return responder.Empty()
		case outgoingFrame := <-outgoingFrames:
			if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
				controller.logger.Warnf("exec: failed to write WebSocket frame to the client: %v", err)

				return responder.Empty()
			}
		case sshErr := <-sshErrCh:
			if sshErr != nil {
				if err := execstream.WriteFrame(ctx, wsConn, &execstream.Frame{
					Type:  execstream.FrameTypeError,
					Error: sshErr.Error(),
				}); err != nil {
					controller.logger.Warnf("exec: failed to write error frame to WebSocket: %v", err)
				}
			}

			if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
				controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
			}

			// Wait for the reader goroutine so it does not outlive the handler;
			// Read() on a WebSocket should unblock shortly after calling Close()
			<-readFramesErrCh

			return responder.Empty()
		case <-time.After(controller.pingInterval):
			// A fresh timer is armed on every loop iteration, so a ping is only
			// sent after pingInterval passes with no other event
			pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
			if err := wsConn.Ping(pingCtx); err != nil {
				controller.logger.Warnf("exec: failed to ping the client, "+
					"connection might time out: %v", err)
			}
			pingCtxCancel()
		case <-ctx.Done():
			controller.logger.Warnf("exec: client disconnected prematurely")

			return responder.Empty()
		}
	}
}
// readFrames reads JSON-encoded execstream frames from the client's WebSocket
// and forwards stdin frames to stdinHandle, looping until the connection
// ends or an invalid frame is received.
//
// Only text messages are decoded; messages of any other type are skipped.
// A stdin frame with empty data closes stdinHandle, and any stdin frame
// received after that is an error. stdinHandle may be nil, in which case
// every stdin frame is an error.
//
// It returns nil when the client closes the WebSocket with a normal-closure
// status. It returns a non-nil error for any other read failure, an
// undecodable payload, an unexpected frame type, or a stdin write/close
// failure.
func (controller *Controller) readFrames(
	ctx context.Context,
	wsConn *websocket.Conn,
	stdinHandle io.WriteCloser,
) error {
	for {
		messageType, payloadBytes, err := wsConn.Read(ctx)
		if err != nil {
			var closeErr websocket.CloseError
			if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
				return nil
			}

			return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
		}

		if messageType != websocket.MessageText {
			continue
		}

		var frame execstream.Frame
		if err := json.Unmarshal(payloadBytes, &frame); err != nil {
			return fmt.Errorf("failed to decode frame received from WebSocket: %w", err)
		}

		switch frame.Type {
		case execstream.FrameTypeStdin:
			if stdinHandle == nil {
				return fmt.Errorf("failed to handle %q frame: stdin is not enabled "+
					"for this exec session or was already closed", frame.Type)
			}

			// An empty stdin frame signals the end of input
			if len(frame.Data) == 0 {
				if err := stdinHandle.Close(); err != nil {
					return fmt.Errorf("failed to handle %q frame: failed to close "+
						"stdin: %w", frame.Type, err)
				}

				// Reject any further stdin frames via the nil check above
				stdinHandle = nil

				continue
			}

			if _, err := stdinHandle.Write(frame.Data); err != nil {
				return fmt.Errorf("failed to handle %q frame: failed to write "+
					"to stdin: %w", frame.Type, err)
			}
		default:
			return fmt.Errorf("unexpected frame type received: %q", frame.Type)
		}
	}
}