Supporting reconnecting to `/exec` socket (#434)
This commit is contained in:
parent
88506b1adb
commit
e6a3314f58
121
api/openapi.yaml
121
api/openapi.yaml
|
|
@ -429,11 +429,33 @@ paths:
|
|||
parameters:
|
||||
- in: query
|
||||
name: command
|
||||
description: Command to execute
|
||||
description: |
|
||||
Command to execute.
|
||||
|
||||
Required when starting a new exec session. May be omitted when reconnecting to an
|
||||
existing session identified by the `session` parameter.
|
||||
schema:
|
||||
type: string
|
||||
minLength: 1
|
||||
required: true
|
||||
required: false
|
||||
- in: query
|
||||
name: session
|
||||
description: |
|
||||
Optional stable exec session identifier. When present, websocket disconnects detach
|
||||
from the command instead of terminating it, and later requests with the same VM name
|
||||
and session id may reconnect and request buffered history.
|
||||
schema:
|
||||
type: string
|
||||
minLength: 1
|
||||
required: false
|
||||
- in: query
|
||||
name: cmux_session_id
|
||||
description: |
|
||||
Compatibility alias for `session`. Prefer `session` for new Orchard clients.
|
||||
schema:
|
||||
type: string
|
||||
minLength: 1
|
||||
required: false
|
||||
- in: query
|
||||
name: stdin
|
||||
description: |
|
||||
|
|
@ -483,7 +505,9 @@ paths:
|
|||
'400':
|
||||
description: Invalid parameters were supplied
|
||||
'404':
|
||||
description: VM resource with the given name doesn't exist
|
||||
description: VM resource with the given name or reconnectable exec session doesn't exist
|
||||
'409':
|
||||
description: Reconnectable exec session already exists with a different command
|
||||
'503':
|
||||
description: Controller failed to establish a connection with the VM
|
||||
/vms/{name}/ip:
|
||||
|
|
@ -800,10 +824,18 @@ components:
|
|||
description: WebSocket frame from Orchard Client to the Orchard Controller
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/ExecClientFrameStdin'
|
||||
- $ref: '#/components/schemas/ExecClientFrameHistory'
|
||||
- $ref: '#/components/schemas/ExecClientFrameAck'
|
||||
- $ref: '#/components/schemas/ExecClientFrameDetach'
|
||||
- $ref: '#/components/schemas/ExecClientFrameClose'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
stdin: '#/components/schemas/ExecClientFrameStdin'
|
||||
history: '#/components/schemas/ExecClientFrameHistory'
|
||||
ack: '#/components/schemas/ExecClientFrameAck'
|
||||
detach: '#/components/schemas/ExecClientFrameDetach'
|
||||
close: '#/components/schemas/ExecClientFrameClose'
|
||||
ExecClientFrameStdin:
|
||||
description: Send bytes to the process standard input
|
||||
type: object
|
||||
|
|
@ -822,6 +854,56 @@ components:
|
|||
example:
|
||||
type: stdin
|
||||
data: aGVsbG8K
|
||||
ExecClientFrameHistory:
|
||||
description: Request buffered output strictly newer than the supplied watermark
|
||||
type: object
|
||||
required: [ type, watermark ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ history ]
|
||||
watermark:
|
||||
type: integer
|
||||
format: int64
|
||||
minimum: 0
|
||||
example:
|
||||
type: history
|
||||
watermark: 42
|
||||
ExecClientFrameAck:
|
||||
description: Acknowledge that output has been durably consumed through the supplied watermark
|
||||
type: object
|
||||
required: [ type, watermark ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ ack ]
|
||||
watermark:
|
||||
type: integer
|
||||
format: int64
|
||||
minimum: 0
|
||||
example:
|
||||
type: ack
|
||||
watermark: 42
|
||||
ExecClientFrameDetach:
|
||||
description: Detach this websocket while leaving the remote command running
|
||||
type: object
|
||||
required: [ type ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ detach ]
|
||||
example:
|
||||
type: detach
|
||||
ExecClientFrameClose:
|
||||
description: Close the reconnectable exec session and terminate the remote command
|
||||
type: object
|
||||
required: [ type ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ close ]
|
||||
example:
|
||||
type: close
|
||||
ExecControllerFrame:
|
||||
description: WebSocket frame from Orchard Controller to the Orchard Client
|
||||
oneOf:
|
||||
|
|
@ -829,6 +911,7 @@ components:
|
|||
- $ref: '#/components/schemas/ExecControllerFrameStderr'
|
||||
- $ref: '#/components/schemas/ExecControllerFrameExit'
|
||||
- $ref: '#/components/schemas/ExecControllerFrameError'
|
||||
- $ref: '#/components/schemas/ExecControllerFrameNoMoreHistory'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
|
|
@ -836,6 +919,7 @@ components:
|
|||
stderr: '#/components/schemas/ExecControllerFrameStderr'
|
||||
exit: '#/components/schemas/ExecControllerFrameExit'
|
||||
error: '#/components/schemas/ExecControllerFrameError'
|
||||
no_more_history: '#/components/schemas/ExecControllerFrameNoMoreHistory'
|
||||
ExecControllerFrameStdout:
|
||||
description: Standard output from the process
|
||||
type: object
|
||||
|
|
@ -848,6 +932,10 @@ components:
|
|||
type: string
|
||||
format: byte
|
||||
description: Base64-encoded standard output bytes from the process
|
||||
watermark:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Monotonic output watermark present on reconnectable sessions
|
||||
example:
|
||||
type: stdout
|
||||
data: aGVsbG8K
|
||||
|
|
@ -863,6 +951,10 @@ components:
|
|||
type: string
|
||||
format: byte
|
||||
description: Base64-encoded standard error bytes from the process
|
||||
watermark:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Monotonic output watermark present on reconnectable sessions
|
||||
example:
|
||||
type: stderr
|
||||
data: aGVsbG8K
|
||||
|
|
@ -882,6 +974,10 @@ components:
|
|||
type: integer
|
||||
format: int32
|
||||
description: Process exit code
|
||||
watermark:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Monotonic output watermark present on reconnectable sessions
|
||||
example:
|
||||
type: exit
|
||||
exit:
|
||||
|
|
@ -897,9 +993,28 @@ components:
|
|||
error:
|
||||
type: string
|
||||
description: Error message text
|
||||
watermark:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Monotonic output watermark present on reconnectable sessions
|
||||
example:
|
||||
type: error
|
||||
error: Failed to establish SSH connection to a VM
|
||||
ExecControllerFrameNoMoreHistory:
|
||||
description: Marker indicating that the requested replay range has been fully sent
|
||||
type: object
|
||||
required: [ type, watermark ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [ no_more_history ]
|
||||
watermark:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Highest watermark known to the exec session at the time of replay
|
||||
example:
|
||||
type: no_more_history
|
||||
watermark: 42
|
||||
Event:
|
||||
title: Generic Resource Event
|
||||
type: object
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ var noExperimentalRPCV2 bool
|
|||
var experimentalPingInterval time.Duration
|
||||
var experimentalDisableDBCompression bool
|
||||
var workerOfflineTimeout time.Duration
|
||||
var execSessionExitTTL time.Duration
|
||||
var synthetic bool
|
||||
|
||||
func newRunCommand() *cobra.Command {
|
||||
|
|
@ -87,6 +88,8 @@ func newRunCommand() *cobra.Command {
|
|||
"duration (e.g. 60s or 5m30s) after which a worker is considered offline for the purposes "+
|
||||
"of scheduling (no new VMs will be scheduled on such worker and already assigned VMs will be "+
|
||||
"marked as failed)")
|
||||
cmd.Flags().DurationVar(&execSessionExitTTL, "exec-session-exit-ttl", 10*time.Minute,
|
||||
"duration to retain reconnectable exec session history after the command exits")
|
||||
|
||||
// Hidden flags
|
||||
cmd.Flags().BoolVar(&synthetic, "synthetic", false, "")
|
||||
|
|
@ -147,6 +150,7 @@ func runController(cmd *cobra.Command, args []string) (err error) {
|
|||
controller.WithListenAddr(address),
|
||||
controller.WithDataDir(dataDir),
|
||||
controller.WithWorkerOfflineTimeout(workerOfflineTimeout),
|
||||
controller.WithExecSessionExitTTL(execSessionExitTTL),
|
||||
controller.WithLogger(logger),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
|
@ -28,9 +27,13 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
|
||||
// Retrieve and parse path and query parameters
|
||||
name := ctx.Param("name")
|
||||
sessionID := ctx.Query("session")
|
||||
if sessionID == "" {
|
||||
sessionID = ctx.Query("cmux_session_id")
|
||||
}
|
||||
|
||||
command := ctx.Query("command")
|
||||
if command == "" {
|
||||
if sessionID == "" && command == "" {
|
||||
return responder.JSON(http.StatusBadRequest,
|
||||
NewErrorResponse("\"command\" parameter cannot be empty"))
|
||||
}
|
||||
|
|
@ -43,6 +46,20 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
return responder.Code(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if sessionID != "" {
|
||||
return controller.execVMReconnectable(ctx, name, sessionID, command, stdin, wait)
|
||||
}
|
||||
|
||||
return controller.execVMLegacy(ctx, name, command, stdin, wait)
|
||||
}
|
||||
|
||||
func (controller *Controller) execVMLegacy(
|
||||
ctx *gin.Context,
|
||||
name string,
|
||||
command string,
|
||||
stdin bool,
|
||||
wait uint64,
|
||||
) responder.Responder {
|
||||
// Look-up the VM
|
||||
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
||||
defer waitContextCancel()
|
||||
|
|
@ -52,33 +69,27 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
return responderImpl
|
||||
}
|
||||
|
||||
// Establish a port-forwarding connection to a VM's SSH port
|
||||
portForwardConn, err := retry.NewWithData[net.Conn](
|
||||
retry.Context(waitContext),
|
||||
retry.DelayType(retry.FixedDelay),
|
||||
retry.Delay(time.Second),
|
||||
retry.Attempts(0),
|
||||
retry.LastErrorOnly(true),
|
||||
).Do(func() (net.Conn, error) {
|
||||
return controller.portForwardConnection(ctx, waitContext, vm.Worker, vm.UID, 22)
|
||||
})
|
||||
session, err := controller.newSSHExecSession(
|
||||
ctx,
|
||||
waitContext,
|
||||
vm,
|
||||
execSessionKey{vmName: name},
|
||||
command,
|
||||
stdin,
|
||||
nil,
|
||||
legacyExecSessionPolicy,
|
||||
)
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
|
||||
}
|
||||
defer portForwardConn.Close()
|
||||
|
||||
// Establish an SSH connection to a VM
|
||||
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err))
|
||||
}
|
||||
defer exec.Close()
|
||||
|
||||
// Upgrade HTTP request to a WebSocket connection
|
||||
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
})
|
||||
if err != nil {
|
||||
session.closeIfUnused()
|
||||
|
||||
return responder.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
|
|
@ -89,56 +100,177 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
|
||||
// Read WebSocket frames
|
||||
readFramesErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin())
|
||||
return controller.serveExecSession(ctx, wsConn, session)
|
||||
}
|
||||
|
||||
func (controller *Controller) execVMReconnectable(
|
||||
ctx *gin.Context,
|
||||
name string,
|
||||
sessionID string,
|
||||
command string,
|
||||
stdin bool,
|
||||
wait uint64,
|
||||
) responder.Responder {
|
||||
key := execSessionKey{
|
||||
vmName: name,
|
||||
sessionID: sessionID,
|
||||
}
|
||||
|
||||
session, ok := controller.execSessions.get(key)
|
||||
if ok {
|
||||
if !session.commandMatches(command) {
|
||||
return responder.JSON(http.StatusConflict,
|
||||
NewErrorResponse("exec session %q is already running a different command", sessionID))
|
||||
}
|
||||
} else {
|
||||
if command == "" {
|
||||
return responder.JSON(http.StatusNotFound,
|
||||
NewErrorResponse("exec session %q does not exist", sessionID))
|
||||
}
|
||||
|
||||
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
||||
defer waitContextCancel()
|
||||
|
||||
vm, responderImpl := controller.waitForVM(waitContext, name)
|
||||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
}
|
||||
|
||||
var err error
|
||||
session, _, err = controller.execSessions.getOrCreate(waitContext, key, func() (*execSession, error) {
|
||||
return controller.newSSHExecSession(
|
||||
ctx,
|
||||
waitContext,
|
||||
vm,
|
||||
key,
|
||||
command,
|
||||
stdin,
|
||||
controller.execSessions,
|
||||
reconnectableExecSessionPolicy,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
|
||||
}
|
||||
|
||||
if !session.commandMatches(command) {
|
||||
return responder.JSON(http.StatusConflict,
|
||||
NewErrorResponse("exec session %q is already running a different command", sessionID))
|
||||
}
|
||||
}
|
||||
|
||||
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
})
|
||||
if err != nil {
|
||||
session.closeIfUnused()
|
||||
|
||||
return responder.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
|
||||
// Run the command
|
||||
sshErrCh := make(chan error, 1)
|
||||
outgoingFrames := make(chan *execstream.Frame)
|
||||
return controller.serveExecSession(ctx, wsConn, session)
|
||||
}
|
||||
|
||||
func (controller *Controller) newSSHExecSession(
|
||||
_ *gin.Context,
|
||||
waitContext context.Context,
|
||||
vm *v1.VM,
|
||||
key execSessionKey,
|
||||
command string,
|
||||
stdin bool,
|
||||
registry *execSessionRegistry,
|
||||
policy execSessionPolicy,
|
||||
) (*execSession, error) {
|
||||
sessionContext, sessionContextCancel := context.WithCancel(context.Background())
|
||||
|
||||
portForwardConn, err := retry.NewWithData[net.Conn](
|
||||
retry.Context(waitContext),
|
||||
retry.DelayType(retry.FixedDelay),
|
||||
retry.Delay(time.Second),
|
||||
retry.Attempts(0),
|
||||
retry.LastErrorOnly(true),
|
||||
).Do(func() (net.Conn, error) {
|
||||
return controller.portForwardConnection(sessionContext, waitContext, vm.Worker, vm.UID, 22)
|
||||
})
|
||||
if err != nil {
|
||||
sessionContextCancel()
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
|
||||
if err != nil {
|
||||
sessionContextCancel()
|
||||
_ = portForwardConn.Close()
|
||||
|
||||
return nil, fmt.Errorf("failed to establish SSH connection to a VM: %w", err)
|
||||
}
|
||||
|
||||
return newExecSessionWithContext(
|
||||
sessionContext,
|
||||
sessionContextCancel,
|
||||
key,
|
||||
command,
|
||||
exec,
|
||||
portForwardConn,
|
||||
registry,
|
||||
controller.execSessionExitTTL,
|
||||
policy,
|
||||
), nil
|
||||
}
|
||||
|
||||
func (controller *Controller) serveExecSession(
|
||||
ctx *gin.Context,
|
||||
wsConn *websocket.Conn,
|
||||
session *execSession,
|
||||
) responder.Responder {
|
||||
subscriber, err := session.attach()
|
||||
if err != nil {
|
||||
_ = wsConn.Close(websocket.StatusNormalClosure, err.Error())
|
||||
|
||||
return responder.Empty()
|
||||
}
|
||||
defer session.detach(subscriber)
|
||||
session.start()
|
||||
|
||||
readFramesErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
sshErrCh <- exec.Run(ctx, command, outgoingFrames)
|
||||
readFramesErrCh <- controller.readExecSessionFrames(ctx, wsConn, session, subscriber)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case readFramesErr := <-readFramesErrCh:
|
||||
controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr)
|
||||
if readFramesErr != nil &&
|
||||
!errors.Is(readFramesErr, errExecSessionDetached) &&
|
||||
!errors.Is(readFramesErr, errExecSessionClosed) {
|
||||
controller.logger.Warnf("failed to read and process exec frames from WebSocket: %v",
|
||||
readFramesErr)
|
||||
}
|
||||
|
||||
return responder.Empty()
|
||||
case outgoingFrame := <-outgoingFrames:
|
||||
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
|
||||
controller.logger.Warnf("failed to write WebSocket frame to the client: %v", err)
|
||||
case outgoingFrame, ok := <-subscriber.frames:
|
||||
if !ok {
|
||||
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
|
||||
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
|
||||
}
|
||||
|
||||
return responder.Empty()
|
||||
}
|
||||
case sshErr := <-sshErrCh:
|
||||
if sshErr != nil {
|
||||
if err := execstream.WriteFrame(ctx, wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeError,
|
||||
Error: sshErr.Error(),
|
||||
}); err != nil {
|
||||
controller.logger.Warnf("exec: failed to write error frame to WebSocket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
|
||||
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
|
||||
}
|
||||
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
|
||||
controller.logger.Warnf("failed to write exec frame to the client: %v", err)
|
||||
|
||||
if readFramesErrCh != nil {
|
||||
// Read() on a WebSocket should unblock shortly after calling Close()
|
||||
<-readFramesErrCh
|
||||
return responder.Empty()
|
||||
}
|
||||
|
||||
return responder.Empty()
|
||||
case <-time.After(controller.pingInterval):
|
||||
pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
|
||||
if err := wsConn.Ping(pingCtx); err != nil {
|
||||
controller.logger.Warnf("port forwarding: failed to ping the client, "+
|
||||
controller.logger.Warnf("exec: failed to ping the client, "+
|
||||
"connection might time out: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -151,10 +283,16 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
}
|
||||
}
|
||||
|
||||
func (controller *Controller) readFrames(
|
||||
var (
|
||||
errExecSessionDetached = errors.New("exec session detached")
|
||||
errExecSessionClosed = errors.New("exec session closed")
|
||||
)
|
||||
|
||||
func (controller *Controller) readExecSessionFrames(
|
||||
ctx context.Context,
|
||||
wsConn *websocket.Conn,
|
||||
stdinHandle io.WriteCloser,
|
||||
session *execSession,
|
||||
subscriber *execSessionSubscriber,
|
||||
) error {
|
||||
for {
|
||||
var frame execstream.Frame
|
||||
|
|
@ -163,7 +301,7 @@ func (controller *Controller) readFrames(
|
|||
if err != nil {
|
||||
var closeErr websocket.CloseError
|
||||
if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
|
||||
return nil
|
||||
return errExecSessionDetached
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
|
||||
|
|
@ -179,26 +317,35 @@ func (controller *Controller) readFrames(
|
|||
|
||||
switch frame.Type {
|
||||
case execstream.FrameTypeStdin:
|
||||
if stdinHandle == nil {
|
||||
return fmt.Errorf("failed to handle %q frame: this exec session "+
|
||||
"has no stdin is enabled or already closed", frame.Type)
|
||||
if err := session.writeStdin(frame.Data); err != nil {
|
||||
return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
|
||||
}
|
||||
case execstream.FrameTypeHistory:
|
||||
if !session.policy.replayEnabled {
|
||||
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
||||
}
|
||||
|
||||
if len(frame.Data) == 0 {
|
||||
if err := stdinHandle.Close(); err != nil {
|
||||
return fmt.Errorf("failed to handle %q frame: failed to close "+
|
||||
"stdin: %w", frame.Type, err)
|
||||
}
|
||||
|
||||
stdinHandle = nil
|
||||
|
||||
continue
|
||||
session.sendHistory(subscriber, frame.Watermark)
|
||||
case execstream.FrameTypeAck:
|
||||
if !session.policy.replayEnabled {
|
||||
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
||||
}
|
||||
|
||||
if _, err := stdinHandle.Write(frame.Data); err != nil {
|
||||
return fmt.Errorf("failed to handle %q frame: failed to write "+
|
||||
"to stdin: %w", frame.Type, err)
|
||||
session.ack(frame.Watermark)
|
||||
case execstream.FrameTypeDetach:
|
||||
if !session.policy.replayEnabled {
|
||||
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
||||
}
|
||||
|
||||
return errExecSessionDetached
|
||||
case execstream.FrameTypeClose:
|
||||
if !session.policy.replayEnabled {
|
||||
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
||||
}
|
||||
|
||||
session.close()
|
||||
|
||||
return errExecSessionClosed
|
||||
default:
|
||||
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ type Controller struct {
|
|||
ipRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[string]]
|
||||
enableSwaggerDocs bool
|
||||
workerOfflineTimeout time.Duration
|
||||
execSessionExitTTL time.Duration
|
||||
experimentalRPCV2 bool
|
||||
disableDBCompression bool
|
||||
pingInterval time.Duration
|
||||
|
|
@ -64,6 +65,7 @@ type Controller struct {
|
|||
sshSigner ssh.Signer
|
||||
sshNoClientAuth bool
|
||||
sshServer *sshserver.SSHServer
|
||||
execSessions *execSessionRegistry
|
||||
|
||||
single singleflight.Group
|
||||
|
||||
|
|
@ -75,7 +77,9 @@ func New(opts ...Option) (*Controller, error) {
|
|||
connRendezvous: rendezvous.New[rendezvous.ResultWithErrorMessage[net.Conn]](),
|
||||
ipRendezvous: rendezvous.New[rendezvous.ResultWithErrorMessage[string]](),
|
||||
workerOfflineTimeout: 3 * time.Minute,
|
||||
execSessionExitTTL: 10 * time.Minute,
|
||||
pingInterval: 30 * time.Second,
|
||||
execSessions: newExecSessionRegistry(),
|
||||
single: singleflight.Group{},
|
||||
}
|
||||
|
||||
|
|
@ -308,6 +312,8 @@ func (controller *Controller) Run(ctx context.Context) error {
|
|||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
controller.execSessions.closeAll()
|
||||
|
||||
if err := controller.httpServer.Shutdown(ctx); err != nil {
|
||||
controller.logger.Errorf("failed to cleanly shutdown the HTTP server: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,692 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
)
|
||||
|
||||
const execSessionReplayBufferBytes = 4 * 1024 * 1024
|
||||
|
||||
type execSessionPolicy struct {
|
||||
closeOnDetach bool
|
||||
retainAfterExit bool
|
||||
replayEnabled bool
|
||||
}
|
||||
|
||||
var (
|
||||
legacyExecSessionPolicy = execSessionPolicy{
|
||||
closeOnDetach: true,
|
||||
}
|
||||
reconnectableExecSessionPolicy = execSessionPolicy{
|
||||
retainAfterExit: true,
|
||||
replayEnabled: true,
|
||||
}
|
||||
)
|
||||
|
||||
type sshExecRunner interface {
|
||||
Stdin() io.WriteCloser
|
||||
Run(ctx context.Context, command string, outgoingFrames chan<- *execstream.Frame) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type execSessionKey struct {
|
||||
vmName string
|
||||
sessionID string
|
||||
}
|
||||
|
||||
type execSessionCreation struct {
|
||||
done chan struct{}
|
||||
session *execSession
|
||||
err error
|
||||
}
|
||||
|
||||
type execSessionRegistry struct {
|
||||
mu sync.Mutex
|
||||
sessions map[execSessionKey]*execSession
|
||||
creating map[execSessionKey]*execSessionCreation
|
||||
}
|
||||
|
||||
func newExecSessionRegistry() *execSessionRegistry {
|
||||
return &execSessionRegistry{
|
||||
sessions: map[execSessionKey]*execSession{},
|
||||
creating: map[execSessionKey]*execSessionCreation{},
|
||||
}
|
||||
}
|
||||
|
||||
func (registry *execSessionRegistry) get(key execSessionKey) (*execSession, bool) {
|
||||
registry.mu.Lock()
|
||||
defer registry.mu.Unlock()
|
||||
|
||||
session, ok := registry.sessions[key]
|
||||
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (registry *execSessionRegistry) getOrCreate(
|
||||
ctx context.Context,
|
||||
key execSessionKey,
|
||||
create func() (*execSession, error),
|
||||
) (*execSession, bool, error) {
|
||||
registry.mu.Lock()
|
||||
|
||||
if session, ok := registry.sessions[key]; ok {
|
||||
registry.mu.Unlock()
|
||||
|
||||
return session, false, nil
|
||||
}
|
||||
|
||||
if creation, ok := registry.creating[key]; ok {
|
||||
registry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false, ctx.Err()
|
||||
case <-creation.done:
|
||||
if creation.err != nil {
|
||||
return nil, false, creation.err
|
||||
}
|
||||
|
||||
return creation.session, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
creation := &execSessionCreation{done: make(chan struct{})}
|
||||
registry.creating[key] = creation
|
||||
registry.mu.Unlock()
|
||||
|
||||
session, err := create()
|
||||
|
||||
registry.mu.Lock()
|
||||
delete(registry.creating, key)
|
||||
if err == nil {
|
||||
registry.sessions[key] = session
|
||||
}
|
||||
creation.session = session
|
||||
creation.err = err
|
||||
close(creation.done)
|
||||
registry.mu.Unlock()
|
||||
|
||||
return session, true, err
|
||||
}
|
||||
|
||||
func (registry *execSessionRegistry) remove(key execSessionKey, expected *execSession) {
|
||||
registry.mu.Lock()
|
||||
defer registry.mu.Unlock()
|
||||
|
||||
if registry.sessions[key] == expected {
|
||||
delete(registry.sessions, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (registry *execSessionRegistry) closeAll() {
|
||||
registry.mu.Lock()
|
||||
sessions := make([]*execSession, 0, len(registry.sessions))
|
||||
for _, session := range registry.sessions {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
registry.mu.Unlock()
|
||||
|
||||
for _, session := range sessions {
|
||||
session.close()
|
||||
}
|
||||
}
|
||||
|
||||
type execReplayFrame struct {
|
||||
frame *execstream.Frame
|
||||
size int
|
||||
}
|
||||
|
||||
type execReplayBuffer struct {
|
||||
frames []execReplayFrame
|
||||
bufferBytes int
|
||||
nextWatermark uint64
|
||||
ackedWatermark uint64
|
||||
}
|
||||
|
||||
func (buffer *execReplayBuffer) append(frame *execstream.Frame) *execstream.Frame {
|
||||
frame = cloneExecFrame(frame)
|
||||
buffer.nextWatermark++
|
||||
frame.Watermark = buffer.nextWatermark
|
||||
|
||||
frameSize := execFrameSize(frame)
|
||||
buffer.frames = append(buffer.frames, execReplayFrame{
|
||||
frame: frame,
|
||||
size: frameSize,
|
||||
})
|
||||
buffer.bufferBytes += frameSize
|
||||
buffer.trimAcknowledged()
|
||||
buffer.trimToLimit()
|
||||
|
||||
return frame
|
||||
}
|
||||
|
||||
func (buffer *execReplayBuffer) ack(watermark uint64) {
|
||||
if watermark <= buffer.ackedWatermark {
|
||||
return
|
||||
}
|
||||
|
||||
buffer.ackedWatermark = watermark
|
||||
buffer.trimAcknowledged()
|
||||
}
|
||||
|
||||
func (buffer *execReplayBuffer) replayAfter(
|
||||
watermark uint64,
|
||||
frames []*execstream.Frame,
|
||||
) []*execstream.Frame {
|
||||
for _, record := range buffer.frames {
|
||||
if record.frame.Watermark <= watermark {
|
||||
continue
|
||||
}
|
||||
|
||||
frames = append(frames, record.frame)
|
||||
}
|
||||
|
||||
return frames
|
||||
}
|
||||
|
||||
func (buffer *execReplayBuffer) trimAcknowledged() {
|
||||
for len(buffer.frames) > 0 && buffer.frames[0].frame.Watermark <= buffer.ackedWatermark {
|
||||
buffer.bufferBytes -= buffer.frames[0].size
|
||||
buffer.frames = buffer.frames[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (buffer *execReplayBuffer) trimToLimit() {
|
||||
for buffer.bufferBytes > execSessionReplayBufferBytes && len(buffer.frames) > 0 {
|
||||
buffer.bufferBytes -= buffer.frames[0].size
|
||||
buffer.frames = buffer.frames[1:]
|
||||
}
|
||||
}
|
||||
|
||||
type execSessionSubscriber struct {
|
||||
frames chan *execstream.Frame
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
sendMu sync.Mutex
|
||||
sentWatermark uint64
|
||||
}
|
||||
|
||||
func newExecSessionSubscriber() *execSessionSubscriber {
|
||||
return &execSessionSubscriber{
|
||||
frames: make(chan *execstream.Frame, 128),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (subscriber *execSessionSubscriber) enqueue(frame *execstream.Frame) bool {
|
||||
subscriber.sendMu.Lock()
|
||||
defer subscriber.sendMu.Unlock()
|
||||
|
||||
if subscriber.alreadySentLocked(frame) {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-subscriber.closed:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case subscriber.frames <- subscriber.markSentLocked(frame):
|
||||
return true
|
||||
case <-subscriber.closed:
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (subscriber *execSessionSubscriber) sendHistory(frames []*execstream.Frame) bool {
|
||||
for _, frame := range frames {
|
||||
if !subscriber.sendLocked(frame) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (subscriber *execSessionSubscriber) sendLocked(frame *execstream.Frame) bool {
|
||||
if subscriber.alreadySentLocked(frame) {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-subscriber.closed:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case subscriber.frames <- subscriber.markSentLocked(frame):
|
||||
return true
|
||||
case <-subscriber.closed:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (subscriber *execSessionSubscriber) alreadySentLocked(frame *execstream.Frame) bool {
|
||||
return isReplayOutputFrame(frame) &&
|
||||
frame.Watermark != 0 &&
|
||||
frame.Watermark <= subscriber.sentWatermark
|
||||
}
|
||||
|
||||
func (subscriber *execSessionSubscriber) markSentLocked(frame *execstream.Frame) *execstream.Frame {
|
||||
frame = cloneExecFrame(frame)
|
||||
if isReplayOutputFrame(frame) && frame.Watermark > subscriber.sentWatermark {
|
||||
subscriber.sentWatermark = frame.Watermark
|
||||
}
|
||||
|
||||
return frame
|
||||
}
|
||||
|
||||
func (subscriber *execSessionSubscriber) close() {
|
||||
subscriber.closeOnce.Do(func() {
|
||||
close(subscriber.closed)
|
||||
subscriber.sendMu.Lock()
|
||||
close(subscriber.frames)
|
||||
subscriber.sendMu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
type execSession struct {
|
||||
key execSessionKey
|
||||
command string
|
||||
exec sshExecRunner
|
||||
transport net.Conn
|
||||
registry *execSessionRegistry
|
||||
exitTTL time.Duration
|
||||
policy execSessionPolicy
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
mu sync.Mutex
|
||||
stdin io.WriteCloser
|
||||
stdinClosed bool
|
||||
subscribers map[*execSessionSubscriber]struct{}
|
||||
replay execReplayBuffer
|
||||
started bool
|
||||
finished bool
|
||||
closed bool
|
||||
expiryTimer *time.Timer
|
||||
|
||||
startOnce sync.Once
|
||||
done chan struct{}
|
||||
doneOnce sync.Once
|
||||
}
|
||||
|
||||
func newExecSession(
|
||||
key execSessionKey,
|
||||
command string,
|
||||
exec sshExecRunner,
|
||||
transport net.Conn,
|
||||
registry *execSessionRegistry,
|
||||
exitTTL time.Duration,
|
||||
policy execSessionPolicy,
|
||||
) *execSession {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return newExecSessionWithContext(
|
||||
ctx,
|
||||
cancel,
|
||||
key,
|
||||
command,
|
||||
exec,
|
||||
transport,
|
||||
registry,
|
||||
exitTTL,
|
||||
policy,
|
||||
)
|
||||
}
|
||||
|
||||
func newExecSessionWithContext(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
key execSessionKey,
|
||||
command string,
|
||||
exec sshExecRunner,
|
||||
transport net.Conn,
|
||||
registry *execSessionRegistry,
|
||||
exitTTL time.Duration,
|
||||
policy execSessionPolicy,
|
||||
) *execSession {
|
||||
if ctx == nil || cancel == nil {
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
}
|
||||
|
||||
session := &execSession{
|
||||
key: key,
|
||||
command: command,
|
||||
exec: exec,
|
||||
transport: transport,
|
||||
registry: registry,
|
||||
exitTTL: exitTTL,
|
||||
policy: policy,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
stdin: exec.Stdin(),
|
||||
subscribers: map[*execSessionSubscriber]struct{}{},
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (session *execSession) commandMatches(command string) bool {
|
||||
return command == "" || session.command == command
|
||||
}
|
||||
|
||||
func (session *execSession) start() {
|
||||
session.startOnce.Do(func() {
|
||||
session.mu.Lock()
|
||||
if session.closed {
|
||||
session.mu.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
session.started = true
|
||||
session.mu.Unlock()
|
||||
|
||||
go session.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (session *execSession) closeIfUnused() {
|
||||
session.mu.Lock()
|
||||
unused := !session.started && len(session.subscribers) == 0
|
||||
session.mu.Unlock()
|
||||
|
||||
if unused {
|
||||
session.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (session *execSession) attach() (*execSessionSubscriber, error) {
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
if session.closed {
|
||||
return nil, errors.New("exec session is closed")
|
||||
}
|
||||
|
||||
subscriber := newExecSessionSubscriber()
|
||||
session.subscribers[subscriber] = struct{}{}
|
||||
|
||||
return subscriber, nil
|
||||
}
|
||||
|
||||
func (session *execSession) detach(subscriber *execSessionSubscriber) {
|
||||
if session.policy.closeOnDetach {
|
||||
session.close()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
session.detachLocked(subscriber)
|
||||
}
|
||||
|
||||
func (session *execSession) detachLocked(subscriber *execSessionSubscriber) {
|
||||
if _, ok := session.subscribers[subscriber]; !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(session.subscribers, subscriber)
|
||||
subscriber.close()
|
||||
}
|
||||
|
||||
func (session *execSession) writeStdin(data []byte) error {
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
if session.stdin == nil || session.stdinClosed {
|
||||
return errors.New("this exec session has no stdin enabled or it is already closed")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
if err := session.stdin.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session.stdinClosed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := session.stdin.Write(data)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (session *execSession) ack(watermark uint64) {
|
||||
if !session.policy.replayEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
session.replay.ack(watermark)
|
||||
}
|
||||
|
||||
func (session *execSession) sendHistory(
|
||||
subscriber *execSessionSubscriber,
|
||||
watermark uint64,
|
||||
) {
|
||||
if !session.policy.replayEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
session.mu.Lock()
|
||||
|
||||
if _, ok := session.subscribers[subscriber]; !ok {
|
||||
session.mu.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
subscriber.sendMu.Lock()
|
||||
frames := session.replay.replayAfter(watermark, nil)
|
||||
frames = append(frames, &execstream.Frame{
|
||||
Type: execstream.FrameTypeNoMoreHistory,
|
||||
Watermark: session.replay.nextWatermark,
|
||||
})
|
||||
session.mu.Unlock()
|
||||
|
||||
ok := subscriber.sendHistory(frames)
|
||||
subscriber.sendMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
session.dropSubscriber(subscriber)
|
||||
}
|
||||
}
|
||||
|
||||
func (session *execSession) close() {
|
||||
session.mu.Lock()
|
||||
if session.closed {
|
||||
session.mu.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
session.closed = true
|
||||
if session.expiryTimer != nil {
|
||||
session.expiryTimer.Stop()
|
||||
session.expiryTimer = nil
|
||||
}
|
||||
|
||||
subscribers := session.takeSubscribersLocked()
|
||||
session.mu.Unlock()
|
||||
|
||||
closeSubscribers(subscribers)
|
||||
|
||||
session.cancel()
|
||||
_ = session.exec.Close()
|
||||
if session.transport != nil {
|
||||
_ = session.transport.Close()
|
||||
}
|
||||
if session.registry != nil {
|
||||
session.registry.remove(session.key, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (session *execSession) run() {
|
||||
outgoingFrames := make(chan *execstream.Frame)
|
||||
runErrCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
runErrCh <- session.exec.Run(session.ctx, session.command, outgoingFrames)
|
||||
close(outgoingFrames)
|
||||
}()
|
||||
|
||||
for frame := range outgoingFrames {
|
||||
session.recordFrame(frame)
|
||||
}
|
||||
|
||||
runErr := <-runErrCh
|
||||
if runErr != nil && !errors.Is(runErr, context.Canceled) {
|
||||
session.recordFrame(&execstream.Frame{
|
||||
Type: execstream.FrameTypeError,
|
||||
Error: runErr.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
session.markFinished()
|
||||
}
|
||||
|
||||
func (session *execSession) recordFrame(frame *execstream.Frame) {
|
||||
session.mu.Lock()
|
||||
|
||||
if session.closed {
|
||||
session.mu.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if session.policy.replayEnabled {
|
||||
frame = session.replay.append(frame)
|
||||
} else {
|
||||
frame = cloneExecFrame(frame)
|
||||
}
|
||||
|
||||
subscribers := make([]*execSessionSubscriber, 0, len(session.subscribers))
|
||||
for subscriber := range session.subscribers {
|
||||
subscribers = append(subscribers, subscriber)
|
||||
}
|
||||
session.mu.Unlock()
|
||||
|
||||
for _, subscriber := range subscribers {
|
||||
if !subscriber.enqueue(frame) {
|
||||
session.dropSubscriber(subscriber)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (session *execSession) markFinished() {
|
||||
session.mu.Lock()
|
||||
if session.finished {
|
||||
session.mu.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
session.finished = true
|
||||
shouldClose := !session.policy.retainAfterExit
|
||||
if !session.closed && session.policy.retainAfterExit {
|
||||
session.expiryTimer = time.AfterFunc(session.exitTTL, session.expire)
|
||||
}
|
||||
|
||||
var subscribers []*execSessionSubscriber
|
||||
if shouldClose {
|
||||
subscribers = session.takeSubscribersLocked()
|
||||
}
|
||||
session.mu.Unlock()
|
||||
|
||||
closeSubscribers(subscribers)
|
||||
|
||||
session.doneOnce.Do(func() {
|
||||
close(session.done)
|
||||
})
|
||||
|
||||
if shouldClose {
|
||||
session.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (session *execSession) expire() {
|
||||
session.close()
|
||||
}
|
||||
|
||||
func (session *execSession) takeSubscribersLocked() []*execSessionSubscriber {
|
||||
subscribers := make([]*execSessionSubscriber, 0, len(session.subscribers))
|
||||
for subscriber := range session.subscribers {
|
||||
subscribers = append(subscribers, subscriber)
|
||||
}
|
||||
session.subscribers = map[*execSessionSubscriber]struct{}{}
|
||||
|
||||
return subscribers
|
||||
}
|
||||
|
||||
func closeSubscribers(subscribers []*execSessionSubscriber) {
|
||||
for _, subscriber := range subscribers {
|
||||
subscriber.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (session *execSession) dropSubscriber(subscriber *execSessionSubscriber) {
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
session.detachLocked(subscriber)
|
||||
}
|
||||
|
||||
func cloneExecFrame(frame *execstream.Frame) *execstream.Frame {
|
||||
if frame == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := *frame
|
||||
if frame.Data != nil {
|
||||
clone.Data = append([]byte(nil), frame.Data...)
|
||||
}
|
||||
if frame.Exit != nil {
|
||||
exit := *frame.Exit
|
||||
clone.Exit = &exit
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
func execFrameSize(frame *execstream.Frame) int {
|
||||
if frame == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(frame.Data) + len(frame.Error) + 16
|
||||
}
|
||||
|
||||
func isReplayOutputFrame(frame *execstream.Frame) bool {
|
||||
if frame == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch frame.Type {
|
||||
case execstream.FrameTypeStdout,
|
||||
execstream.FrameTypeStderr,
|
||||
execstream.FrameTypeExit,
|
||||
execstream.FrameTypeError:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,320 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeExec struct {
|
||||
stdin io.WriteCloser
|
||||
run func(context.Context, string, chan<- *execstream.Frame) error
|
||||
closeCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (exec *fakeExec) Stdin() io.WriteCloser {
|
||||
return exec.stdin
|
||||
}
|
||||
|
||||
func (exec *fakeExec) Run(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
outgoingFrames chan<- *execstream.Frame,
|
||||
) error {
|
||||
if exec.run != nil {
|
||||
return exec.run(ctx, command, outgoingFrames)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (exec *fakeExec) Close() error {
|
||||
exec.closeCalls.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newManualExecSessionForTest(
|
||||
key execSessionKey,
|
||||
registry *execSessionRegistry,
|
||||
) *execSession {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &execSession{
|
||||
key: key,
|
||||
command: "echo test",
|
||||
exec: &fakeExec{},
|
||||
registry: registry,
|
||||
exitTTL: time.Minute,
|
||||
policy: reconnectableExecSessionPolicy,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
subscribers: map[*execSessionSubscriber]struct{}{},
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecSessionRegistryGetOrCreateReusesInflightCreation(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
key := execSessionKey{vmName: "vm", sessionID: "session"}
|
||||
|
||||
createStarted := make(chan struct{})
|
||||
releaseCreate := make(chan struct{})
|
||||
var createCalls atomic.Int32
|
||||
|
||||
create := func() (*execSession, error) {
|
||||
createCalls.Add(1)
|
||||
close(createStarted)
|
||||
<-releaseCreate
|
||||
|
||||
return newManualExecSessionForTest(key, registry), nil
|
||||
}
|
||||
|
||||
firstDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(firstDone)
|
||||
_, _, err := registry.getOrCreate(context.Background(), key, create)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
<-createStarted
|
||||
|
||||
secondDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(secondDone)
|
||||
_, created, err := registry.getOrCreate(context.Background(), key, create)
|
||||
require.NoError(t, err)
|
||||
require.False(t, created)
|
||||
}()
|
||||
|
||||
close(releaseCreate)
|
||||
|
||||
<-firstDone
|
||||
<-secondDone
|
||||
require.EqualValues(t, 1, createCalls.Load())
|
||||
}
|
||||
|
||||
func TestExecSessionStartRunsCommandOnlyOnce(t *testing.T) {
|
||||
var runCalls atomic.Int32
|
||||
runStarted := make(chan struct{})
|
||||
|
||||
session := newExecSession(
|
||||
execSessionKey{vmName: "vm", sessionID: "session"},
|
||||
"echo test",
|
||||
&fakeExec{
|
||||
run: func(ctx context.Context, _ string, _ chan<- *execstream.Frame) error {
|
||||
runCalls.Add(1)
|
||||
close(runStarted)
|
||||
<-ctx.Done()
|
||||
|
||||
return ctx.Err()
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
time.Minute,
|
||||
reconnectableExecSessionPolicy,
|
||||
)
|
||||
defer session.close()
|
||||
|
||||
session.start()
|
||||
session.start()
|
||||
|
||||
<-runStarted
|
||||
require.EqualValues(t, 1, runCalls.Load())
|
||||
}
|
||||
|
||||
func TestExecSessionHistoryReplayAndAck(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
|
||||
|
||||
session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")})
|
||||
session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStderr, Data: []byte("err")})
|
||||
session.recordFrame(&execstream.Frame{
|
||||
Type: execstream.FrameTypeExit,
|
||||
Exit: &execstream.Exit{Code: 7},
|
||||
})
|
||||
|
||||
subscriber, err := session.attach()
|
||||
require.NoError(t, err)
|
||||
|
||||
session.sendHistory(subscriber, 0)
|
||||
|
||||
require.Equal(t, execstream.FrameTypeStdout, (<-subscriber.frames).Type)
|
||||
require.Equal(t, execstream.FrameTypeStderr, (<-subscriber.frames).Type)
|
||||
require.Equal(t, execstream.FrameTypeExit, (<-subscriber.frames).Type)
|
||||
noMoreHistory := <-subscriber.frames
|
||||
require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type)
|
||||
require.EqualValues(t, 3, noMoreHistory.Watermark)
|
||||
|
||||
session.ack(2)
|
||||
require.Len(t, session.replay.frames, 1)
|
||||
require.EqualValues(t, 3, session.replay.frames[0].frame.Watermark)
|
||||
}
|
||||
|
||||
func TestExecSessionHistoryReplayStreamsPastSubscriberBuffer(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
|
||||
|
||||
const frameCount = 256
|
||||
for i := 0; i < frameCount; i++ {
|
||||
session.recordFrame(&execstream.Frame{
|
||||
Type: execstream.FrameTypeStdout,
|
||||
Data: []byte{byte(i)},
|
||||
})
|
||||
}
|
||||
|
||||
subscriber, err := session.attach()
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
session.sendHistory(subscriber, 0)
|
||||
}()
|
||||
|
||||
for i := 1; i <= frameCount; i++ {
|
||||
frame := <-subscriber.frames
|
||||
require.Equal(t, execstream.FrameTypeStdout, frame.Type)
|
||||
require.EqualValues(t, i, frame.Watermark)
|
||||
}
|
||||
|
||||
noMoreHistory := <-subscriber.frames
|
||||
require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type)
|
||||
require.EqualValues(t, frameCount, noMoreHistory.Watermark)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestExecSessionDetachKeepsProcessAlive(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
|
||||
exec := session.exec.(*fakeExec)
|
||||
|
||||
subscriber, err := session.attach()
|
||||
require.NoError(t, err)
|
||||
|
||||
session.detach(subscriber)
|
||||
|
||||
require.False(t, session.closed)
|
||||
require.EqualValues(t, 0, exec.closeCalls.Load())
|
||||
}
|
||||
|
||||
func TestLegacyExecSessionDetachStopsProcess(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
|
||||
session.policy = legacyExecSessionPolicy
|
||||
exec := session.exec.(*fakeExec)
|
||||
|
||||
subscriber, err := session.attach()
|
||||
require.NoError(t, err)
|
||||
|
||||
session.detach(subscriber)
|
||||
|
||||
require.True(t, session.closed)
|
||||
require.EqualValues(t, 1, exec.closeCalls.Load())
|
||||
}
|
||||
|
||||
func TestLegacyExecSessionDoesNotRetainReplayHistory(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
|
||||
session.policy = legacyExecSessionPolicy
|
||||
|
||||
session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")})
|
||||
|
||||
require.Empty(t, session.replay.frames)
|
||||
require.Zero(t, session.replay.nextWatermark)
|
||||
}
|
||||
|
||||
func TestExecSessionCloseIfUnusedClosesIdleSession(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
key := execSessionKey{vmName: "vm", sessionID: "session"}
|
||||
session := newManualExecSessionForTest(key, registry)
|
||||
exec := session.exec.(*fakeExec)
|
||||
registry.sessions[key] = session
|
||||
|
||||
session.closeIfUnused()
|
||||
|
||||
require.True(t, session.closed)
|
||||
require.EqualValues(t, 1, exec.closeCalls.Load())
|
||||
}
|
||||
|
||||
func TestExecSessionCloseIfUnusedKeepsAttachedSession(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
|
||||
|
||||
_, err := session.attach()
|
||||
require.NoError(t, err)
|
||||
|
||||
session.closeIfUnused()
|
||||
|
||||
require.False(t, session.closed)
|
||||
}
|
||||
|
||||
func TestExecSessionCloseStopsProcessAndRemovesRegistryEntry(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
key := execSessionKey{vmName: "vm", sessionID: "session"}
|
||||
session := newManualExecSessionForTest(key, registry)
|
||||
exec := session.exec.(*fakeExec)
|
||||
registry.sessions[key] = session
|
||||
|
||||
session.close()
|
||||
|
||||
require.True(t, session.closed)
|
||||
require.EqualValues(t, 1, exec.closeCalls.Load())
|
||||
_, ok := registry.get(key)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestExecSessionFinishedEntryExpiresAfterTTL(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
key := execSessionKey{vmName: "vm", sessionID: "session"}
|
||||
session := newManualExecSessionForTest(key, registry)
|
||||
session.exitTTL = 10 * time.Millisecond
|
||||
registry.sessions[key] = session
|
||||
|
||||
session.markFinished()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, ok := registry.get(key)
|
||||
|
||||
return !ok
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestExecSessionFinishKeepsReconnectableSubscriberOpen(t *testing.T) {
|
||||
registry := newExecSessionRegistry()
|
||||
session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
|
||||
|
||||
subscriber, err := session.attach()
|
||||
require.NoError(t, err)
|
||||
|
||||
session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")})
|
||||
session.recordFrame(&execstream.Frame{
|
||||
Type: execstream.FrameTypeExit,
|
||||
Exit: &execstream.Exit{Code: 0},
|
||||
})
|
||||
session.markFinished()
|
||||
|
||||
require.Equal(t, execstream.FrameTypeStdout, (<-subscriber.frames).Type)
|
||||
require.Equal(t, execstream.FrameTypeExit, (<-subscriber.frames).Type)
|
||||
|
||||
session.sendHistory(subscriber, 0)
|
||||
|
||||
noMoreHistory, ok := <-subscriber.frames
|
||||
require.True(t, ok)
|
||||
require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type)
|
||||
require.EqualValues(t, 2, noMoreHistory.Watermark)
|
||||
}
|
||||
|
|
@ -60,6 +60,12 @@ func WithWorkerOfflineTimeout(workerOfflineTimeout time.Duration) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithExecSessionExitTTL(execSessionExitTTL time.Duration) Option {
|
||||
return func(controller *Controller) {
|
||||
controller.execSessionExitTTL = execSessionExitTTL
|
||||
}
|
||||
}
|
||||
|
||||
func WithExperimentalRPCV2() Option {
|
||||
return func(controller *Controller) {
|
||||
controller.experimentalRPCV2 = true
|
||||
|
|
|
|||
|
|
@ -14,11 +14,12 @@ import (
|
|||
)
|
||||
|
||||
type Exec struct {
|
||||
sshClient *ssh.Client
|
||||
sshSession *ssh.Session
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
stdin io.WriteCloser
|
||||
sshClient *ssh.Client
|
||||
sshSession *ssh.Session
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
stdin io.WriteCloser
|
||||
stdinReader *io.PipeReader
|
||||
}
|
||||
|
||||
func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, error) {
|
||||
|
|
@ -52,14 +53,10 @@ func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, err
|
|||
}
|
||||
|
||||
if stdin {
|
||||
exec.stdin, err = sshSession.StdinPipe()
|
||||
if err != nil {
|
||||
_ = sshSession.Close()
|
||||
_ = sshClient.Close()
|
||||
|
||||
return nil, fmt.Errorf("failed to create standard input pipe "+
|
||||
"for the SSH session: %w", err)
|
||||
}
|
||||
stdinReader, stdinWriter := io.Pipe()
|
||||
sshSession.Stdin = stdinReader
|
||||
exec.stdinReader = stdinReader
|
||||
exec.stdin = stdinWriter
|
||||
}
|
||||
|
||||
exec.stdout, err = sshSession.StdoutPipe()
|
||||
|
|
@ -92,6 +89,12 @@ func (exec *Exec) Run(
|
|||
command string,
|
||||
outgoingFrames chan<- *execstream.Frame,
|
||||
) error {
|
||||
if exec.stdinReader != nil {
|
||||
defer func() {
|
||||
_ = exec.stdinReader.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
if err := exec.sshSession.Start(command); err != nil {
|
||||
return fmt.Errorf("failed to start command %q: %w", command, err)
|
||||
}
|
||||
|
|
@ -188,6 +191,13 @@ func ioStreamReader(
|
|||
}
|
||||
|
||||
func (exec *Exec) Close() error {
|
||||
if exec.stdin != nil {
|
||||
_ = exec.stdin.Close()
|
||||
}
|
||||
if exec.stdinReader != nil {
|
||||
_ = exec.stdinReader.Close()
|
||||
}
|
||||
|
||||
if err := exec.sshSession.Close(); err != nil {
|
||||
_ = exec.sshClient.Close()
|
||||
|
||||
|
|
|
|||
|
|
@ -10,18 +10,24 @@ import (
|
|||
type FrameType string
|
||||
|
||||
const (
|
||||
FrameTypeStdin FrameType = "stdin"
|
||||
FrameTypeStdout FrameType = "stdout"
|
||||
FrameTypeStderr FrameType = "stderr"
|
||||
FrameTypeExit FrameType = "exit"
|
||||
FrameTypeError FrameType = "error"
|
||||
FrameTypeStdin FrameType = "stdin"
|
||||
FrameTypeStdout FrameType = "stdout"
|
||||
FrameTypeStderr FrameType = "stderr"
|
||||
FrameTypeExit FrameType = "exit"
|
||||
FrameTypeError FrameType = "error"
|
||||
FrameTypeHistory FrameType = "history"
|
||||
FrameTypeNoMoreHistory FrameType = "no_more_history"
|
||||
FrameTypeAck FrameType = "ack"
|
||||
FrameTypeDetach FrameType = "detach"
|
||||
FrameTypeClose FrameType = "close"
|
||||
)
|
||||
|
||||
type Frame struct {
|
||||
Type FrameType `json:"type"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
Exit *Exit `json:"exit,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Type FrameType `json:"type"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
Exit *Exit `json:"exit,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Watermark uint64 `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type Exit struct {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
package execstream
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFrameRoundTripsWatermark(t *testing.T) {
|
||||
frame := Frame{
|
||||
Type: FrameTypeHistory,
|
||||
Watermark: 42,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(frame)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded Frame
|
||||
err = json.Unmarshal(payload, &decoded)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, frame, decoded)
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package tests_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -132,6 +133,176 @@ func TestVMExecScript(t *testing.T) {
|
|||
require.Equal(t, websocket.StatusNormalClosure, closeError.Code)
|
||||
}
|
||||
|
||||
func TestVMExecSessionReconnectHistory(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
Command: "sh -c 'echo first; sleep 1; echo second'",
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstFrame := readFrame(t, wsConn)
|
||||
require.Equal(t, execstream.FrameTypeStdout, firstFrame.Type)
|
||||
require.Equal(t, "first\n", string(firstFrame.Data))
|
||||
require.EqualValues(t, 1, firstFrame.Watermark)
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
|
||||
require.NoError(t, err)
|
||||
_ = wsConn.CloseNow()
|
||||
|
||||
// Let the detached process finish so this test verifies partial replay
|
||||
// without relying on live-output timing.
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeHistory,
|
||||
Watermark: firstFrame.Watermark,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
frames := readFramesUntilExit(t, wsConn)
|
||||
require.Len(t, framesByType(frames, execstream.FrameTypeStdout), 1)
|
||||
require.Equal(t, "second\n", string(framesByType(frames, execstream.FrameTypeStdout)[0].Data))
|
||||
require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code)
|
||||
}
|
||||
|
||||
func TestVMExecSessionReconnectAfterExit(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
Command: "sh -c 'echo replay-me'",
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
|
||||
require.NoError(t, err)
|
||||
_ = wsConn.CloseNow()
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory})
|
||||
require.NoError(t, err)
|
||||
|
||||
frames := readFramesUntilExit(t, wsConn)
|
||||
require.Equal(t, "replay-me\n", string(framesByType(frames, execstream.FrameTypeStdout)[0].Data))
|
||||
require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code)
|
||||
}
|
||||
|
||||
func TestVMExecSessionReplayPreservesStreams(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
Command: "sh -c 'echo out1; sleep 1; echo err1 >&2; sleep 1; echo out2; sleep 1; echo err2 >&2'",
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
|
||||
require.NoError(t, err)
|
||||
_ = wsConn.CloseNow()
|
||||
|
||||
time.Sleep(4 * time.Second)
|
||||
|
||||
wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory})
|
||||
require.NoError(t, err)
|
||||
|
||||
frames := readFramesUntilExit(t, wsConn)
|
||||
require.Equal(t, []execstream.FrameType{
|
||||
execstream.FrameTypeStdout,
|
||||
execstream.FrameTypeStderr,
|
||||
execstream.FrameTypeStdout,
|
||||
execstream.FrameTypeStderr,
|
||||
execstream.FrameTypeExit,
|
||||
}, frameTypes(frames))
|
||||
require.Equal(t, "out1\n", string(frames[0].Data))
|
||||
require.Equal(t, "err1\n", string(frames[1].Data))
|
||||
require.Equal(t, "out2\n", string(frames[2].Data))
|
||||
require.Equal(t, "err2\n", string(frames[3].Data))
|
||||
}
|
||||
|
||||
func TestVMExecSessionStdinSurvivesReconnect(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
Command: "/bin/cat",
|
||||
Stdin: true,
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeStdin,
|
||||
Data: []byte("one\n"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
frame := readFrame(t, wsConn)
|
||||
require.Equal(t, execstream.FrameTypeStdout, frame.Type)
|
||||
require.Equal(t, "one\n", string(frame.Data))
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
|
||||
require.NoError(t, err)
|
||||
_ = wsConn.CloseNow()
|
||||
|
||||
wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
|
||||
WaitSeconds: 30,
|
||||
Session: sessionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeStdin,
|
||||
Data: []byte("two\n"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
|
||||
Type: execstream.FrameTypeStdin,
|
||||
Data: []byte{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory})
|
||||
require.NoError(t, err)
|
||||
|
||||
frames := readFramesUntilExit(t, wsConn)
|
||||
stdoutFrames := framesByType(frames, execstream.FrameTypeStdout)
|
||||
require.Len(t, stdoutFrames, 2)
|
||||
require.Equal(t, "one\n", string(stdoutFrames[0].Data))
|
||||
require.Equal(t, "two\n", string(stdoutFrames[1].Data))
|
||||
require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code)
|
||||
}
|
||||
|
||||
func prepareForExec(t *testing.T) (*client.Client, string) {
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
|
|
@ -153,14 +324,62 @@ func prepareForExec(t *testing.T) (*client.Client, string) {
|
|||
}
|
||||
|
||||
func readFrame(t *testing.T, wsConn *websocket.Conn) *execstream.Frame {
|
||||
t.Helper()
|
||||
|
||||
var frame execstream.Frame
|
||||
|
||||
messageType, payloadBytes, err := wsConn.Read(t.Context())
|
||||
readCtx, readCtxCancel := context.WithTimeout(t.Context(), 30*time.Second)
|
||||
defer readCtxCancel()
|
||||
|
||||
messageType, payloadBytes, err := wsConn.Read(readCtx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, websocket.MessageText, messageType)
|
||||
|
||||
err = json.Unmarshal(payloadBytes, &frame)
|
||||
require.NoError(t, err)
|
||||
if frame.Type == execstream.FrameTypeError {
|
||||
require.FailNowf(t, "exec stream error", "%s", frame.Error)
|
||||
}
|
||||
|
||||
return &frame
|
||||
}
|
||||
|
||||
func readFramesUntilExit(t *testing.T, wsConn *websocket.Conn) []*execstream.Frame {
|
||||
t.Helper()
|
||||
|
||||
var frames []*execstream.Frame
|
||||
|
||||
for {
|
||||
frame := readFrame(t, wsConn)
|
||||
if frame.Type == execstream.FrameTypeNoMoreHistory {
|
||||
continue
|
||||
}
|
||||
|
||||
frames = append(frames, frame)
|
||||
if frame.Type == execstream.FrameTypeExit {
|
||||
return frames
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func framesByType(frames []*execstream.Frame, frameType execstream.FrameType) []*execstream.Frame {
|
||||
var result []*execstream.Frame
|
||||
|
||||
for _, frame := range frames {
|
||||
if frame.Type == frameType {
|
||||
result = append(result, frame)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func frameTypes(frames []*execstream.Frame) []execstream.FrameType {
|
||||
var result []execstream.FrameType
|
||||
|
||||
for _, frame := range frames {
|
||||
result = append(result, frame.Type)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,6 +47,13 @@ type EventsPageOptions struct {
|
|||
Cursor string
|
||||
}
|
||||
|
||||
type ExecSessionOptions struct {
|
||||
Command string
|
||||
Stdin bool
|
||||
WaitSeconds uint16
|
||||
Session string
|
||||
}
|
||||
|
||||
func (service *VMsService) Create(ctx context.Context, vm *v1.VM) error {
|
||||
err := service.client.request(ctx, http.MethodPost, "vms",
|
||||
vm, nil, nil)
|
||||
|
|
@ -164,12 +171,33 @@ func (service *VMsService) Exec(
|
|||
stdin bool,
|
||||
waitSeconds uint16,
|
||||
) (*websocket.Conn, error) {
|
||||
return service.ExecSession(ctx, name, ExecSessionOptions{
|
||||
Command: command,
|
||||
Stdin: stdin,
|
||||
WaitSeconds: waitSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
func (service *VMsService) ExecSession(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
options ExecSessionOptions,
|
||||
) (*websocket.Conn, error) {
|
||||
params := map[string]string{
|
||||
"wait": strconv.FormatUint(uint64(options.WaitSeconds), 10),
|
||||
}
|
||||
if options.Command != "" {
|
||||
params["command"] = options.Command
|
||||
}
|
||||
if options.Stdin {
|
||||
params["stdin"] = strconv.FormatBool(true)
|
||||
}
|
||||
if options.Session != "" {
|
||||
params["session"] = options.Session
|
||||
}
|
||||
|
||||
return service.client.wsRequestRaw(ctx, fmt.Sprintf("vms/%s/exec", url.PathEscape(name)),
|
||||
map[string]string{
|
||||
"command": command,
|
||||
"stdin": strconv.FormatBool(stdin),
|
||||
"wait": strconv.FormatUint(uint64(waitSeconds), 10),
|
||||
})
|
||||
params)
|
||||
}
|
||||
|
||||
func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint16) (string, error) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecSessionBuildsReconnectableQuery(t *testing.T) {
|
||||
var query map[string][]string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
query = request.URL.Query()
|
||||
|
||||
conn, err := websocket.Accept(writer, request, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseNow()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
devClient, err := New(WithAddress(server.URL))
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := devClient.VMs().ExecSession(t.Context(), "vm", ExecSessionOptions{
|
||||
Command: "echo hello",
|
||||
Stdin: true,
|
||||
WaitSeconds: 7,
|
||||
Session: "resume-me",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer conn.CloseNow()
|
||||
|
||||
require.Equal(t, []string{"echo hello"}, query["command"])
|
||||
require.Equal(t, []string{"true"}, query["stdin"])
|
||||
require.Equal(t, []string{"7"}, query["wait"])
|
||||
require.Equal(t, []string{"resume-me"}, query["session"])
|
||||
}
|
||||
Loading…
Reference in New Issue