controller(api): new "GET /vms/{name}/exec" WebSocket-based endpoint
This commit is contained in:
parent
c4b7378883
commit
5111950f11
175
api/openapi.yaml
175
api/openapi.yaml
|
|
@ -415,6 +415,77 @@ paths:
|
|||
description: VM resource with the given name doesn't exist
|
||||
'503':
|
||||
description: Failed to establish connection with the worker responsible for the specified VM
|
||||
/vms/{name}/exec:
|
||||
parameters:
|
||||
- in: path
|
||||
name: name
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
get:
|
||||
summary: "Execute a command inside a VM using WebSocket protocol"
|
||||
tags:
|
||||
- vms
|
||||
parameters:
|
||||
- in: query
|
||||
name: command
|
||||
description: Command to execute
|
||||
schema:
|
||||
type: string
|
||||
minLength: 1
|
||||
required: true
|
||||
- in: query
|
||||
name: stdin
|
||||
description: |
|
||||
Whether to allocate an interactive standard input for the command
|
||||
|
||||
When enabled, make sure to close the standard input by sending a `ExecClientFrameStdin`
|
||||
frame with an empty data. Otherwise the command might never terminate waiting for the
|
||||
standard input to end.
|
||||
schema:
|
||||
type: boolean
|
||||
default: false
|
||||
required: false
|
||||
- in: query
|
||||
name: wait
|
||||
description: Duration in seconds for the VM to become available if it's not available already
|
||||
schema:
|
||||
type: integer
|
||||
minimum: 0
|
||||
maximum: 65535
|
||||
default: 10
|
||||
required: false
|
||||
- in: header
|
||||
name: Connection
|
||||
description: WebSocket protocol required header
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- in: header
|
||||
name: Upgrade
|
||||
description: WebSocket protocol required header
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'101':
|
||||
description: |
|
||||
The connection has been upgraded to WebSocket. Messages exchanged after upgrade:
|
||||
|
||||
* Orchard Client → Orchard Controller: `ExecClientFrame`
|
||||
* Orchard Controller → Orchard Client : `ExecControllerFrame`
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/ExecClientFrame'
|
||||
- $ref: '#/components/schemas/ExecControllerFrame'
|
||||
'400':
|
||||
description: Invalid parameters were supplied
|
||||
'404':
|
||||
description: VM resource with the given name doesn't exist
|
||||
'503':
|
||||
description: Controller failed to establish a connection with the VM
|
||||
/vms/{name}/ip:
|
||||
parameters:
|
||||
- in: path
|
||||
|
|
@ -693,6 +764,110 @@ components:
|
|||
ip:
|
||||
type: string
|
||||
description: The resolved IP address
|
||||
ExecClientFrame:
|
||||
description: WebSocket frame from Orchard Client to the Orchard Controller
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/ExecClientFrameStdin'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
stdin: '#/components/schemas/ExecClientFrameStdin'
|
||||
ExecClientFrameStdin:
|
||||
description: Send bytes to the process standard input
|
||||
type: object
|
||||
required: [ type, data ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ stdin ]
|
||||
data:
|
||||
type: string
|
||||
format: byte
|
||||
description: |
|
||||
Base64-encoded standard input bytes to the process
|
||||
|
||||
Empty payload indicates EOF and causes standard input to be closed.
|
||||
example:
|
||||
type: stdin
|
||||
data: aGVsbG8K
|
||||
ExecControllerFrame:
|
||||
description: WebSocket frame from Orchard Controller to the Orchard Client
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/ExecControllerFrameStdout'
|
||||
- $ref: '#/components/schemas/ExecControllerFrameStderr'
|
||||
- $ref: '#/components/schemas/ExecControllerFrameExit'
|
||||
- $ref: '#/components/schemas/ExecControllerFrameError'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
stdout: '#/components/schemas/ExecControllerFrameStdout'
|
||||
stderr: '#/components/schemas/ExecControllerFrameStderr'
|
||||
exit: '#/components/schemas/ExecControllerFrameExit'
|
||||
error: '#/components/schemas/ExecControllerFrameError'
|
||||
ExecControllerFrameStdout:
|
||||
description: Standard output from the process
|
||||
type: object
|
||||
required: [ type, data ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ stdout ]
|
||||
data:
|
||||
type: string
|
||||
format: byte
|
||||
description: Base64-encoded standard output bytes from the process
|
||||
example:
|
||||
type: stdout
|
||||
data: aGVsbG8K
|
||||
ExecControllerFrameStderr:
|
||||
description: Standard error from the process
|
||||
type: object
|
||||
required: [ type, data ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ stderr ]
|
||||
data:
|
||||
type: string
|
||||
format: byte
|
||||
description: Base64-encoded standard error bytes from the process
|
||||
example:
|
||||
type: stderr
|
||||
data: aGVsbG8K
|
||||
ExecControllerFrameExit:
|
||||
description: Process termination details
|
||||
type: object
|
||||
required: [ type, exit ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ exit ]
|
||||
exit:
|
||||
type: object
|
||||
required: [ code ]
|
||||
properties:
|
||||
code:
|
||||
type: integer
|
||||
format: int32
|
||||
description: Process exit code
|
||||
example:
|
||||
type: exit
|
||||
exit:
|
||||
code: 0
|
||||
ExecControllerFrameError:
|
||||
description: Error message encountered while running the process
|
||||
type: object
|
||||
required: [ type, error ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ error ]
|
||||
error:
|
||||
type: string
|
||||
description: Error message text
|
||||
example:
|
||||
type: error
|
||||
error: Failed to establish SSH connection to a VM
|
||||
Event:
|
||||
title: Generic Resource Event
|
||||
type: object
|
||||
|
|
|
|||
|
|
@ -171,6 +171,9 @@ func (controller *Controller) initAPI() *gin.Engine {
|
|||
v1.GET("/vms/:name/port-forward", func(c *gin.Context) {
|
||||
controller.portForwardVM(c).Respond(c)
|
||||
})
|
||||
v1.GET("/vms/:name/exec", func(c *gin.Context) {
|
||||
controller.execVM(c).Respond(c)
|
||||
})
|
||||
v1.GET("/vms/:name/ip", func(c *gin.Context) {
|
||||
controller.ip(c).Respond(c)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,347 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
"github.com/cirruslabs/orchard/internal/responder"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
||||
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
||||
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
||||
return responder
|
||||
}
|
||||
|
||||
// Retrieve and parse path and query parameters
|
||||
name := ctx.Param("name")
|
||||
|
||||
command := ctx.Query("command")
|
||||
if command == "" {
|
||||
return responder.JSON(http.StatusBadRequest,
|
||||
NewErrorResponse("\"command\" parameter cannot be empty"))
|
||||
}
|
||||
|
||||
stdin := ctx.Query("stdin") == "true"
|
||||
|
||||
waitRaw := ctx.DefaultQuery("wait", "10")
|
||||
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
||||
if err != nil {
|
||||
return responder.Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Look-up the VM
|
||||
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
||||
defer waitContextCancel()
|
||||
|
||||
vm, responderImpl := controller.waitForVM(waitContext, name)
|
||||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
}
|
||||
|
||||
// Establish a port-forwarding connection to a VM's SSH port
|
||||
portForwardConn, portForwardCancel, err := controller.portForwardConnection(ctx, waitContext,
|
||||
vm.Worker, vm.UID, 22)
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
|
||||
}
|
||||
defer portForwardCancel()
|
||||
|
||||
// Upgrade HTTP request to a WebSocket connection
|
||||
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
})
|
||||
if err != nil {
|
||||
return responder.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
// Ensure that we always close the accepted WebSocket connection,
|
||||
// otherwise resource leak is possible[1]
|
||||
//
|
||||
// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
|
||||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
|
||||
// Start a goroutine that establishes an SSH connection to a VM and runs a command
|
||||
sshErrCh := make(chan error, 1)
|
||||
stdinHandleCh := make(chan io.WriteCloser, 1)
|
||||
outgoingFrames := make(chan *execstream.Frame)
|
||||
go func() {
|
||||
sshErrCh <- controller.execSSH(ctx, portForwardConn, vm, stdin, stdinHandleCh, command, outgoingFrames)
|
||||
}()
|
||||
|
||||
var readFramesErrCh chan error
|
||||
|
||||
for {
|
||||
select {
|
||||
case stdinHandle := <-stdinHandleCh:
|
||||
// SSH session is almost up, we have the standard input handle,
|
||||
// so we can start a goroutine that reads WebSocket frames
|
||||
readFramesErrCh = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
readFramesErrCh <- controller.readFrames(ctx, wsConn, stdinHandle)
|
||||
}()
|
||||
case readFramesErr := <-readFramesErrCh:
|
||||
controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr)
|
||||
|
||||
return responder.Empty()
|
||||
case outgoingFrame := <-outgoingFrames:
|
||||
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
|
||||
controller.logger.Warnf("failed to write WebSocket frame to the client: %v", err)
|
||||
|
||||
return responder.Empty()
|
||||
}
|
||||
case sshErr := <-sshErrCh:
|
||||
if sshErr != nil {
|
||||
if err := execstream.WriteFrame(ctx, wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeError,
|
||||
Error: sshErr.Error(),
|
||||
}); err != nil {
|
||||
controller.logger.Warnf("exec: failed to write error frame to WebSocket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
|
||||
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
|
||||
}
|
||||
|
||||
if readFramesErrCh != nil {
|
||||
// Read() on a WebSocket should unblock shortly after calling Close()
|
||||
<-readFramesErrCh
|
||||
}
|
||||
|
||||
return responder.Empty()
|
||||
case <-time.After(controller.pingInterval):
|
||||
pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
|
||||
if err := wsConn.Ping(pingCtx); err != nil {
|
||||
controller.logger.Warnf("port forwarding: failed to ping the client, "+
|
||||
"connection might time out: %v", err)
|
||||
}
|
||||
|
||||
pingCtxCancel()
|
||||
case <-ctx.Done():
|
||||
controller.logger.Warnf("client disconnected prematurely")
|
||||
|
||||
return responder.Empty()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (controller *Controller) readFrames(
|
||||
ctx context.Context,
|
||||
wsConn *websocket.Conn,
|
||||
stdinHandle io.WriteCloser,
|
||||
) error {
|
||||
for {
|
||||
var frame execstream.Frame
|
||||
|
||||
messageType, payloadBytes, err := wsConn.Read(ctx)
|
||||
if err != nil {
|
||||
var closeErr websocket.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
|
||||
}
|
||||
|
||||
if messageType != websocket.MessageText {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(payloadBytes, &frame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch frame.Type {
|
||||
case execstream.FrameTypeStdin:
|
||||
if stdinHandle == nil {
|
||||
return fmt.Errorf("failed to handle %q frame: this exec session "+
|
||||
"has no stdin is enabled or already closed", frame.Type)
|
||||
}
|
||||
|
||||
if len(frame.Data) == 0 {
|
||||
if err := stdinHandle.Close(); err != nil {
|
||||
return fmt.Errorf("failed to handle %q frame: failed to close "+
|
||||
"stdin: %w", frame.Type, err)
|
||||
}
|
||||
|
||||
stdinHandle = nil
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := stdinHandle.Write(frame.Data); err != nil {
|
||||
return fmt.Errorf("failed to handle %q frame: failed to write "+
|
||||
"to stdin: %w", frame.Type, err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (controller *Controller) execSSH(
|
||||
ctx context.Context,
|
||||
portForwardConn net.Conn,
|
||||
vm *v1.VM,
|
||||
stdin bool,
|
||||
stdinHandleCh chan<- io.WriteCloser,
|
||||
command string,
|
||||
outgoingFrames chan<- *execstream.Frame,
|
||||
) error {
|
||||
// Establish an SSH connection
|
||||
sshConn, sshChans, sshReqs, err := ssh.NewClientConn(portForwardConn, "", &ssh.ClientConfig{
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
User: vm.SSHUsername(),
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(vm.SSHPassword()),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to a new SSH connection: %w", err)
|
||||
}
|
||||
|
||||
sshClient := ssh.NewClient(sshConn, sshChans, sshReqs)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Create a new SSH session
|
||||
sshSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create a new SSH session: %w", err)
|
||||
}
|
||||
defer sshSession.Close()
|
||||
|
||||
if stdin {
|
||||
stdin, err := sshSession.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create standard input pipe: %w", err)
|
||||
}
|
||||
|
||||
stdinHandleCh <- stdin
|
||||
} else {
|
||||
stdinHandleCh <- nil
|
||||
}
|
||||
|
||||
stdout, err := sshSession.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create standard output pipe: %w", err)
|
||||
}
|
||||
|
||||
stderr, err := sshSession.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create standard error pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := sshSession.Start(command); err != nil {
|
||||
return fmt.Errorf("failed to start command %q: %w", command, err)
|
||||
}
|
||||
|
||||
// Read bytes from standard output and standard error and stream them as frames
|
||||
ioGroup, ioGroupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
ioGroup.Go(func() error {
|
||||
return ioStreamReader(ioGroupCtx, stdout, execstream.FrameTypeStdout, outgoingFrames)
|
||||
})
|
||||
ioGroup.Go(func() error {
|
||||
return ioStreamReader(ioGroupCtx, stderr, execstream.FrameTypeStderr, outgoingFrames)
|
||||
})
|
||||
|
||||
sshWaitErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
sshWaitErrCh <- sshSession.Wait()
|
||||
}()
|
||||
|
||||
// Wait for SSH command terminate while respecting context
|
||||
var sshWaitErr error
|
||||
|
||||
select {
|
||||
case sshWaitErr = <-sshWaitErrCh:
|
||||
// Proceed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Wait for the I/O to complete, otherwise we may
|
||||
// miss some bits of the command's output/error
|
||||
if err := ioGroup.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Post an exit event
|
||||
exitFrame := &execstream.Frame{
|
||||
Type: execstream.FrameTypeExit,
|
||||
Exit: execstream.Exit{
|
||||
Code: 0,
|
||||
},
|
||||
}
|
||||
|
||||
if sshWaitErr != nil {
|
||||
var sshExitError *ssh.ExitError
|
||||
if errors.As(sshWaitErr, &sshExitError) {
|
||||
exitFrame.Exit.Code = int32(sshExitError.ExitStatus())
|
||||
} else {
|
||||
return fmt.Errorf("failed to execute command %q: %w", command, sshWaitErr)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case outgoingFrames <- exitFrame:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func ioStreamReader(
|
||||
ctx context.Context,
|
||||
r io.Reader,
|
||||
frameType execstream.FrameType,
|
||||
ch chan<- *execstream.Frame,
|
||||
) error {
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
|
||||
if n > 0 {
|
||||
frame := &execstream.Frame{
|
||||
Type: frameType,
|
||||
Data: slices.Clone(buf[:n]),
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- frame:
|
||||
// Proceed
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package controller
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
|
@ -44,10 +45,11 @@ func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responde
|
|||
if err != nil {
|
||||
return responder.Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Look-up the VM
|
||||
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
||||
defer waitContextCancel()
|
||||
|
||||
// Look-up the VM
|
||||
vm, responderImpl := controller.waitForVM(waitContext, name)
|
||||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
|
|
@ -201,3 +203,56 @@ func (controller *Controller) waitForVM(ctx context.Context, name string) (*v1.V
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (controller *Controller) portForwardConnection(
|
||||
ctx context.Context,
|
||||
notifyContext context.Context,
|
||||
workerName string,
|
||||
vmUID string,
|
||||
port uint32,
|
||||
) (net.Conn, context.CancelFunc, error) {
|
||||
// Create a rendezvous connection point
|
||||
rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx)
|
||||
|
||||
session := uuid.New().String()
|
||||
boomerangConnCh, boomerangConnCancel := controller.connRendezvous.Request(rendezvousCtx, session)
|
||||
|
||||
cancel := func() {
|
||||
boomerangConnCancel()
|
||||
rendezvousCtxCancel()
|
||||
}
|
||||
|
||||
// Send request to a worker to initiate a port forwarding connection back to us
|
||||
err := controller.workerNotifier.Notify(notifyContext, workerName, &rpc.WatchInstruction{
|
||||
Action: &rpc.WatchInstruction_PortForwardAction{
|
||||
PortForwardAction: &rpc.WatchInstruction_PortForward{
|
||||
Session: session,
|
||||
VmUid: vmUID,
|
||||
Port: port,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
cancel()
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to request port forwarding from the worker %s: %v",
|
||||
workerName, err)
|
||||
}
|
||||
|
||||
// Wait for the worker to respond
|
||||
select {
|
||||
case rendezvousResponse := <-boomerangConnCh:
|
||||
if rendezvousResponse.ErrorMessage != "" {
|
||||
cancel()
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to establish port forwarding session on the worker: %s",
|
||||
rendezvousResponse.ErrorMessage)
|
||||
}
|
||||
|
||||
return rendezvousResponse.Result, cancel, nil
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
package execstream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
type FrameType string
|
||||
|
||||
const (
|
||||
FrameTypeStdin FrameType = "stdin"
|
||||
FrameTypeStdout FrameType = "stdout"
|
||||
FrameTypeStderr FrameType = "stderr"
|
||||
FrameTypeExit FrameType = "exit"
|
||||
FrameTypeError FrameType = "error"
|
||||
)
|
||||
|
||||
type Frame struct {
|
||||
Type FrameType `json:"type"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
Exit Exit `json:"exit,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type Exit struct {
|
||||
Code int32 `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
func WriteFrame(ctx context.Context, wsConn *websocket.Conn, frame *Frame) error {
|
||||
frameBytes, err := json.Marshal(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return wsConn.Write(ctx, websocket.MessageText, frameBytes)
|
||||
}
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
"github.com/cirruslabs/orchard/internal/imageconstant"
|
||||
"github.com/cirruslabs/orchard/internal/tests/devcontroller"
|
||||
"github.com/cirruslabs/orchard/internal/tests/wait"
|
||||
"github.com/cirruslabs/orchard/pkg/client"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/coder/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVMExecWithoutStdin(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
|
||||
// Run a command
|
||||
wsConn, err := devClient.VMs().Exec(t.Context(), vmName, "/bin/echo -n 'Hello, World!'",
|
||||
false, 30)
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
// Ensure that the command outputs "Hello, World!" and terminates successfully
|
||||
frame := readFrame(t, wsConn)
|
||||
require.Equal(t, execstream.FrameTypeStdout, frame.Type)
|
||||
require.Equal(t, "Hello, World!", string(frame.Data))
|
||||
|
||||
frame = readFrame(t, wsConn)
|
||||
require.Equal(t, execstream.FrameTypeExit, frame.Type)
|
||||
require.EqualValues(t, 0, frame.Exit.Code)
|
||||
|
||||
// Ensure that Orchard Controller gracefully terminates the WebSocket connection
|
||||
_, _, err = wsConn.Read(t.Context())
|
||||
var closeError websocket.CloseError
|
||||
require.ErrorAs(t, err, &closeError)
|
||||
require.Equal(t, websocket.StatusNormalClosure, closeError.Code)
|
||||
}
|
||||
|
||||
func TestVMExecWithStdin(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
|
||||
// Run a command
|
||||
wsConn, err := devClient.VMs().Exec(t.Context(), vmName, "/bin/cat", true, 30)
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
// Populate and close the command's standard input
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeStdin,
|
||||
Data: []byte("Hello, World!\n"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeStdin,
|
||||
Data: []byte{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure that the command outputs "Hello, World!\n" and terminates successfully
|
||||
frame := readFrame(t, wsConn)
|
||||
require.Equal(t, execstream.FrameTypeStdout, frame.Type)
|
||||
require.Equal(t, "Hello, World!\n", string(frame.Data))
|
||||
|
||||
frame = readFrame(t, wsConn)
|
||||
require.Equal(t, execstream.FrameTypeExit, frame.Type)
|
||||
require.EqualValues(t, 0, frame.Exit.Code)
|
||||
|
||||
// Ensure that Orchard Controller gracefully terminates the WebSocket connection
|
||||
_, _, err = wsConn.Read(t.Context())
|
||||
var closeError websocket.CloseError
|
||||
require.ErrorAs(t, err, &closeError)
|
||||
require.Equal(t, websocket.StatusNormalClosure, closeError.Code)
|
||||
}
|
||||
|
||||
func TestVMExecScript(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
|
||||
script := "sh -c 'echo stdout-line; echo stderr-line >&2; IFS= read -r line; echo stdin:$line; exit 7'"
|
||||
|
||||
wsConn, err := devClient.VMs().Exec(t.Context(), vmName, script, true, 30)
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
// Populate and close the command's standard input
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeStdin,
|
||||
Data: []byte("hello-from-test\n"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeStdin,
|
||||
Data: []byte{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Collect output and wait for command's exit
|
||||
var stdout, stderr bytes.Buffer
|
||||
var exitFrame *execstream.Frame
|
||||
|
||||
for exitFrame == nil {
|
||||
frame := readFrame(t, wsConn)
|
||||
|
||||
switch frame.Type {
|
||||
case execstream.FrameTypeStdout:
|
||||
stdout.Write(frame.Data)
|
||||
case execstream.FrameTypeStderr:
|
||||
stderr.Write(frame.Data)
|
||||
case execstream.FrameTypeExit:
|
||||
exitFrame = frame
|
||||
default:
|
||||
t.Fatalf("unexpected frame type %q", frame.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that we've observed everything as per script
|
||||
require.EqualValues(t, 7, exitFrame.Exit.Code)
|
||||
require.Equal(t, "stdout-line\nstdin:hello-from-test\n", stdout.String())
|
||||
require.Equal(t, "stderr-line\n", stderr.String())
|
||||
|
||||
// Ensure that Orchard Controller gracefully terminates the WebSocket connection
|
||||
_, _, err = wsConn.Read(t.Context())
|
||||
var closeError websocket.CloseError
|
||||
require.ErrorAs(t, err, &closeError)
|
||||
require.Equal(t, websocket.StatusNormalClosure, closeError.Code)
|
||||
}
|
||||
|
||||
func prepareForExec(t *testing.T) (*client.Client, string) {
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
vmName := "test-vm-exec-" + uuid.NewString()
|
||||
|
||||
err := devClient.VMs().Create(t.Context(), &v1.VM{
|
||||
Meta: v1.Meta{
|
||||
Name: vmName,
|
||||
},
|
||||
Image: imageconstant.DefaultMacosImage,
|
||||
Headless: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err := devClient.VMs().Get(t.Context(), vmName)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Waiting for the VM to start. Current status: %s", vm.Status)
|
||||
|
||||
return vm.Status == v1.VMStatusRunning
|
||||
}), "failed to start a VM")
|
||||
|
||||
return devClient, vmName
|
||||
}
|
||||
|
||||
func readFrame(t *testing.T, wsConn *websocket.Conn) *execstream.Frame {
|
||||
var frame execstream.Frame
|
||||
|
||||
messageType, payloadBytes, err := wsConn.Read(t.Context())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, websocket.MessageText, messageType)
|
||||
|
||||
err = json.Unmarshal(payloadBytes, &frame)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &frame
|
||||
}
|
||||
|
|
@ -286,6 +286,19 @@ func (client *Client) wsRequest(
|
|||
path string,
|
||||
params map[string]string,
|
||||
) (net.Conn, error) {
|
||||
wsConn, err := client.wsRequestRaw(ctx, path, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return websocket.NetConn(ctx, wsConn, websocket.MessageBinary), nil
|
||||
}
|
||||
|
||||
func (client *Client) wsRequestRaw(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
params map[string]string,
|
||||
) (*websocket.Conn, error) {
|
||||
endpointURL := client.formatPath(path)
|
||||
|
||||
// Adapt HTTP scheme to WebSocket scheme
|
||||
|
|
@ -321,7 +334,7 @@ func (client *Client) wsRequest(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (client *Client) formatPath(path string) *url.URL {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
type VMsService struct {
|
||||
|
|
@ -156,6 +157,21 @@ func (service *VMsService) PortForward(
|
|||
})
|
||||
}
|
||||
|
||||
func (service *VMsService) Exec(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
command string,
|
||||
stdin bool,
|
||||
waitSeconds uint16,
|
||||
) (*websocket.Conn, error) {
|
||||
return service.client.wsRequestRaw(ctx, fmt.Sprintf("vms/%s/exec", url.PathEscape(name)),
|
||||
map[string]string{
|
||||
"command": command,
|
||||
"stdin": strconv.FormatBool(stdin),
|
||||
"wait": strconv.FormatUint(uint64(waitSeconds), 10),
|
||||
})
|
||||
}
|
||||
|
||||
func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint16) (string, error) {
|
||||
result := struct {
|
||||
IP string `json:"ip"`
|
||||
|
|
|
|||
|
|
@ -111,6 +111,22 @@ func (vm *VM) Match(filter Filter) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func (vm *VM) SSHUsername() string {
|
||||
if vm.Username != "" {
|
||||
return vm.Username
|
||||
}
|
||||
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func (vm *VM) SSHPassword() string {
|
||||
if vm.Password != "" {
|
||||
return vm.Password
|
||||
}
|
||||
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func (vm *VM) IsScheduled() bool {
|
||||
if ConditionExists(vm.Conditions, ConditionTypeScheduled) {
|
||||
return ConditionIsTrue(vm.Conditions, ConditionTypeScheduled)
|
||||
|
|
|
|||
Loading…
Reference in New Issue