207 lines
5.7 KiB
Go
207 lines
5.7 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/avast/retry-go/v5"
|
|
"github.com/cirruslabs/orchard/internal/controller/sshexec"
|
|
"github.com/cirruslabs/orchard/internal/execstream"
|
|
"github.com/cirruslabs/orchard/internal/responder"
|
|
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
"github.com/coder/websocket"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
|
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
|
return responder
|
|
}
|
|
|
|
// Retrieve and parse path and query parameters
|
|
name := ctx.Param("name")
|
|
|
|
command := ctx.Query("command")
|
|
if command == "" {
|
|
return responder.JSON(http.StatusBadRequest,
|
|
NewErrorResponse("\"command\" parameter cannot be empty"))
|
|
}
|
|
|
|
stdin := ctx.Query("stdin") == "true"
|
|
|
|
waitRaw := ctx.DefaultQuery("wait", "10")
|
|
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
|
if err != nil {
|
|
return responder.Code(http.StatusBadRequest)
|
|
}
|
|
|
|
// Look-up the VM
|
|
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
|
defer waitContextCancel()
|
|
|
|
vm, responderImpl := controller.waitForVM(waitContext, name)
|
|
if responderImpl != nil {
|
|
return responderImpl
|
|
}
|
|
|
|
// Establish a port-forwarding connection to a VM's SSH port
|
|
portForwardConn, err := retry.NewWithData[net.Conn](
|
|
retry.Context(waitContext),
|
|
retry.DelayType(retry.FixedDelay),
|
|
retry.Delay(time.Second),
|
|
retry.Attempts(0),
|
|
retry.LastErrorOnly(true),
|
|
).Do(func() (net.Conn, error) {
|
|
return controller.portForwardConnection(ctx, waitContext, vm.Worker, vm.UID, 22)
|
|
})
|
|
if err != nil {
|
|
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
|
|
}
|
|
defer portForwardConn.Close()
|
|
|
|
// Establish an SSH connection to a VM
|
|
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
|
|
if err != nil {
|
|
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err))
|
|
}
|
|
defer exec.Close()
|
|
|
|
// Upgrade HTTP request to a WebSocket connection
|
|
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
|
OriginPatterns: []string{"*"},
|
|
})
|
|
if err != nil {
|
|
return responder.Error(err)
|
|
}
|
|
defer func() {
|
|
// Ensure that we always close the accepted WebSocket connection,
|
|
// otherwise resource leak is possible[1]
|
|
//
|
|
// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
|
|
_ = wsConn.CloseNow()
|
|
}()
|
|
|
|
// Read WebSocket frames
|
|
readFramesErrCh := make(chan error, 1)
|
|
go func() {
|
|
readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin())
|
|
}()
|
|
|
|
// Run the command
|
|
sshErrCh := make(chan error, 1)
|
|
outgoingFrames := make(chan *execstream.Frame)
|
|
go func() {
|
|
sshErrCh <- exec.Run(ctx, command, outgoingFrames)
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case readFramesErr := <-readFramesErrCh:
|
|
controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr)
|
|
|
|
return responder.Empty()
|
|
case outgoingFrame := <-outgoingFrames:
|
|
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
|
|
controller.logger.Warnf("failed to write WebSocket frame to the client: %v", err)
|
|
|
|
return responder.Empty()
|
|
}
|
|
case sshErr := <-sshErrCh:
|
|
if sshErr != nil {
|
|
if err := execstream.WriteFrame(ctx, wsConn, &execstream.Frame{
|
|
Type: execstream.FrameTypeError,
|
|
Error: sshErr.Error(),
|
|
}); err != nil {
|
|
controller.logger.Warnf("exec: failed to write error frame to WebSocket: %v", err)
|
|
}
|
|
}
|
|
|
|
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
|
|
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
|
|
}
|
|
|
|
if readFramesErrCh != nil {
|
|
// Read() on a WebSocket should unblock shortly after calling Close()
|
|
<-readFramesErrCh
|
|
}
|
|
|
|
return responder.Empty()
|
|
case <-time.After(controller.pingInterval):
|
|
pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
|
|
if err := wsConn.Ping(pingCtx); err != nil {
|
|
controller.logger.Warnf("port forwarding: failed to ping the client, "+
|
|
"connection might time out: %v", err)
|
|
}
|
|
|
|
pingCtxCancel()
|
|
case <-ctx.Done():
|
|
controller.logger.Warnf("client disconnected prematurely")
|
|
|
|
return responder.Empty()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (controller *Controller) readFrames(
|
|
ctx context.Context,
|
|
wsConn *websocket.Conn,
|
|
stdinHandle io.WriteCloser,
|
|
) error {
|
|
for {
|
|
var frame execstream.Frame
|
|
|
|
messageType, payloadBytes, err := wsConn.Read(ctx)
|
|
if err != nil {
|
|
var closeErr websocket.CloseError
|
|
if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
|
|
}
|
|
|
|
if messageType != websocket.MessageText {
|
|
continue
|
|
}
|
|
|
|
if err := json.Unmarshal(payloadBytes, &frame); err != nil {
|
|
return err
|
|
}
|
|
|
|
switch frame.Type {
|
|
case execstream.FrameTypeStdin:
|
|
if stdinHandle == nil {
|
|
return fmt.Errorf("failed to handle %q frame: this exec session "+
|
|
"has no stdin is enabled or already closed", frame.Type)
|
|
}
|
|
|
|
if len(frame.Data) == 0 {
|
|
if err := stdinHandle.Close(); err != nil {
|
|
return fmt.Errorf("failed to handle %q frame: failed to close "+
|
|
"stdin: %w", frame.Type, err)
|
|
}
|
|
|
|
stdinHandle = nil
|
|
|
|
continue
|
|
}
|
|
|
|
if _, err := stdinHandle.Write(frame.Data); err != nil {
|
|
return fmt.Errorf("failed to handle %q frame: failed to write "+
|
|
"to stdin: %w", frame.Type, err)
|
|
}
|
|
default:
|
|
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
|
}
|
|
}
|
|
}
|