Supporting reconnecting to `/exec` socket (#434)

This commit is contained in:
Fedor Korotkov 2026-05-04 15:49:19 -04:00 committed by GitHub
parent 88506b1adb
commit e6a3314f58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1715 additions and 99 deletions

View File

@ -429,11 +429,33 @@ paths:
parameters:
- in: query
name: command
description: Command to execute
description: |
Command to execute.
Required when starting a new exec session. May be omitted when reconnecting to an
existing session identified by the `session` parameter.
schema:
type: string
minLength: 1
required: true
required: false
- in: query
name: session
description: |
Optional stable exec session identifier. When present, websocket disconnects detach
from the command instead of terminating it, and later requests with the same VM name
and session id may reconnect and request buffered history.
schema:
type: string
minLength: 1
required: false
- in: query
name: cmux_session_id
description: |
Compatibility alias for `session`. Prefer `session` for new Orchard clients.
schema:
type: string
minLength: 1
required: false
- in: query
name: stdin
description: |
@ -483,7 +505,9 @@ paths:
'400':
description: Invalid parameters were supplied
'404':
description: VM resource with the given name doesn't exist
description: VM resource with the given name or reconnectable exec session doesn't exist
'409':
description: Reconnectable exec session already exists with a different command
'503':
description: Controller failed to establish a connection with the VM
/vms/{name}/ip:
@ -800,10 +824,18 @@ components:
description: WebSocket frame from Orchard Client to the Orchard Controller
oneOf:
- $ref: '#/components/schemas/ExecClientFrameStdin'
- $ref: '#/components/schemas/ExecClientFrameHistory'
- $ref: '#/components/schemas/ExecClientFrameAck'
- $ref: '#/components/schemas/ExecClientFrameDetach'
- $ref: '#/components/schemas/ExecClientFrameClose'
discriminator:
propertyName: type
mapping:
stdin: '#/components/schemas/ExecClientFrameStdin'
history: '#/components/schemas/ExecClientFrameHistory'
ack: '#/components/schemas/ExecClientFrameAck'
detach: '#/components/schemas/ExecClientFrameDetach'
close: '#/components/schemas/ExecClientFrameClose'
ExecClientFrameStdin:
description: Send bytes to the process standard input
type: object
@ -822,6 +854,56 @@ components:
example:
type: stdin
data: aGVsbG8K
ExecClientFrameHistory:
description: Request buffered output strictly newer than the supplied watermark
type: object
required: [ type, watermark ]
properties:
type:
type: string
enum: [ history ]
watermark:
type: integer
format: int64
minimum: 0
example:
type: history
watermark: 42
ExecClientFrameAck:
description: Acknowledge that output has been durably consumed through the supplied watermark
type: object
required: [ type, watermark ]
properties:
type:
type: string
enum: [ ack ]
watermark:
type: integer
format: int64
minimum: 0
example:
type: ack
watermark: 42
ExecClientFrameDetach:
description: Detach this websocket while leaving the remote command running
type: object
required: [ type ]
properties:
type:
type: string
enum: [ detach ]
example:
type: detach
ExecClientFrameClose:
description: Close the reconnectable exec session and terminate the remote command
type: object
required: [ type ]
properties:
type:
type: string
enum: [ close ]
example:
type: close
ExecControllerFrame:
description: WebSocket frame from Orchard Controller to the Orchard Client
oneOf:
@ -829,6 +911,7 @@ components:
- $ref: '#/components/schemas/ExecControllerFrameStderr'
- $ref: '#/components/schemas/ExecControllerFrameExit'
- $ref: '#/components/schemas/ExecControllerFrameError'
- $ref: '#/components/schemas/ExecControllerFrameNoMoreHistory'
discriminator:
propertyName: type
mapping:
@ -836,6 +919,7 @@ components:
stderr: '#/components/schemas/ExecControllerFrameStderr'
exit: '#/components/schemas/ExecControllerFrameExit'
error: '#/components/schemas/ExecControllerFrameError'
no_more_history: '#/components/schemas/ExecControllerFrameNoMoreHistory'
ExecControllerFrameStdout:
description: Standard output from the process
type: object
@ -848,6 +932,10 @@ components:
type: string
format: byte
description: Base64-encoded standard output bytes from the process
watermark:
type: integer
format: int64
description: Monotonic output watermark present on reconnectable sessions
example:
type: stdout
data: aGVsbG8K
@ -863,6 +951,10 @@ components:
type: string
format: byte
description: Base64-encoded standard error bytes from the process
watermark:
type: integer
format: int64
description: Monotonic output watermark present on reconnectable sessions
example:
type: stderr
data: aGVsbG8K
@ -882,6 +974,10 @@ components:
type: integer
format: int32
description: Process exit code
watermark:
type: integer
format: int64
description: Monotonic output watermark present on reconnectable sessions
example:
type: exit
exit:
@ -897,9 +993,28 @@ components:
error:
type: string
description: Error message text
watermark:
type: integer
format: int64
description: Monotonic output watermark present on reconnectable sessions
example:
type: error
error: Failed to establish SSH connection to a VM
ExecControllerFrameNoMoreHistory:
description: Marker indicating that the requested replay range has been fully sent
type: object
required: [ type, watermark ]
properties:
type:
type: string
enum: [ no_more_history ]
watermark:
type: integer
format: int64
description: Highest watermark known to the exec session at the time of replay
example:
type: no_more_history
watermark: 42
Event:
title: Generic Resource Event
type: object

View File

@ -35,6 +35,7 @@ var noExperimentalRPCV2 bool
var experimentalPingInterval time.Duration
var experimentalDisableDBCompression bool
var workerOfflineTimeout time.Duration
var execSessionExitTTL time.Duration
var synthetic bool
func newRunCommand() *cobra.Command {
@ -87,6 +88,8 @@ func newRunCommand() *cobra.Command {
"duration (e.g. 60s or 5m30s) after which a worker is considered offline for the purposes "+
"of scheduling (no new VMs will be scheduled on such worker and already assigned VMs will be "+
"marked as failed)")
cmd.Flags().DurationVar(&execSessionExitTTL, "exec-session-exit-ttl", 10*time.Minute,
"duration to retain reconnectable exec session history after the command exits")
// Hidden flags
cmd.Flags().BoolVar(&synthetic, "synthetic", false, "")
@ -147,6 +150,7 @@ func runController(cmd *cobra.Command, args []string) (err error) {
controller.WithListenAddr(address),
controller.WithDataDir(dataDir),
controller.WithWorkerOfflineTimeout(workerOfflineTimeout),
controller.WithExecSessionExitTTL(execSessionExitTTL),
controller.WithLogger(logger),
}

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
@ -28,9 +27,13 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
// Retrieve and parse path and query parameters
name := ctx.Param("name")
sessionID := ctx.Query("session")
if sessionID == "" {
sessionID = ctx.Query("cmux_session_id")
}
command := ctx.Query("command")
if command == "" {
if sessionID == "" && command == "" {
return responder.JSON(http.StatusBadRequest,
NewErrorResponse("\"command\" parameter cannot be empty"))
}
@ -43,6 +46,20 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
return responder.Code(http.StatusBadRequest)
}
if sessionID != "" {
return controller.execVMReconnectable(ctx, name, sessionID, command, stdin, wait)
}
return controller.execVMLegacy(ctx, name, command, stdin, wait)
}
func (controller *Controller) execVMLegacy(
ctx *gin.Context,
name string,
command string,
stdin bool,
wait uint64,
) responder.Responder {
// Look-up the VM
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
defer waitContextCancel()
@ -52,33 +69,27 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
return responderImpl
}
// Establish a port-forwarding connection to a VM's SSH port
portForwardConn, err := retry.NewWithData[net.Conn](
retry.Context(waitContext),
retry.DelayType(retry.FixedDelay),
retry.Delay(time.Second),
retry.Attempts(0),
retry.LastErrorOnly(true),
).Do(func() (net.Conn, error) {
return controller.portForwardConnection(ctx, waitContext, vm.Worker, vm.UID, 22)
})
session, err := controller.newSSHExecSession(
ctx,
waitContext,
vm,
execSessionKey{vmName: name},
command,
stdin,
nil,
legacyExecSessionPolicy,
)
if err != nil {
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
}
defer portForwardConn.Close()
// Establish an SSH connection to a VM
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
if err != nil {
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err))
}
defer exec.Close()
// Upgrade HTTP request to a WebSocket connection
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
})
if err != nil {
session.closeIfUnused()
return responder.Error(err)
}
defer func() {
@ -89,56 +100,177 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
_ = wsConn.CloseNow()
}()
// Read WebSocket frames
readFramesErrCh := make(chan error, 1)
go func() {
readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin())
return controller.serveExecSession(ctx, wsConn, session)
}
func (controller *Controller) execVMReconnectable(
ctx *gin.Context,
name string,
sessionID string,
command string,
stdin bool,
wait uint64,
) responder.Responder {
key := execSessionKey{
vmName: name,
sessionID: sessionID,
}
session, ok := controller.execSessions.get(key)
if ok {
if !session.commandMatches(command) {
return responder.JSON(http.StatusConflict,
NewErrorResponse("exec session %q is already running a different command", sessionID))
}
} else {
if command == "" {
return responder.JSON(http.StatusNotFound,
NewErrorResponse("exec session %q does not exist", sessionID))
}
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
defer waitContextCancel()
vm, responderImpl := controller.waitForVM(waitContext, name)
if responderImpl != nil {
return responderImpl
}
var err error
session, _, err = controller.execSessions.getOrCreate(waitContext, key, func() (*execSession, error) {
return controller.newSSHExecSession(
ctx,
waitContext,
vm,
key,
command,
stdin,
controller.execSessions,
reconnectableExecSessionPolicy,
)
})
if err != nil {
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err))
}
if !session.commandMatches(command) {
return responder.JSON(http.StatusConflict,
NewErrorResponse("exec session %q is already running a different command", sessionID))
}
}
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
})
if err != nil {
session.closeIfUnused()
return responder.Error(err)
}
defer func() {
_ = wsConn.CloseNow()
}()
// Run the command
sshErrCh := make(chan error, 1)
outgoingFrames := make(chan *execstream.Frame)
return controller.serveExecSession(ctx, wsConn, session)
}
func (controller *Controller) newSSHExecSession(
_ *gin.Context,
waitContext context.Context,
vm *v1.VM,
key execSessionKey,
command string,
stdin bool,
registry *execSessionRegistry,
policy execSessionPolicy,
) (*execSession, error) {
sessionContext, sessionContextCancel := context.WithCancel(context.Background())
portForwardConn, err := retry.NewWithData[net.Conn](
retry.Context(waitContext),
retry.DelayType(retry.FixedDelay),
retry.Delay(time.Second),
retry.Attempts(0),
retry.LastErrorOnly(true),
).Do(func() (net.Conn, error) {
return controller.portForwardConnection(sessionContext, waitContext, vm.Worker, vm.UID, 22)
})
if err != nil {
sessionContextCancel()
return nil, err
}
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
if err != nil {
sessionContextCancel()
_ = portForwardConn.Close()
return nil, fmt.Errorf("failed to establish SSH connection to a VM: %w", err)
}
return newExecSessionWithContext(
sessionContext,
sessionContextCancel,
key,
command,
exec,
portForwardConn,
registry,
controller.execSessionExitTTL,
policy,
), nil
}
func (controller *Controller) serveExecSession(
ctx *gin.Context,
wsConn *websocket.Conn,
session *execSession,
) responder.Responder {
subscriber, err := session.attach()
if err != nil {
_ = wsConn.Close(websocket.StatusNormalClosure, err.Error())
return responder.Empty()
}
defer session.detach(subscriber)
session.start()
readFramesErrCh := make(chan error, 1)
go func() {
sshErrCh <- exec.Run(ctx, command, outgoingFrames)
readFramesErrCh <- controller.readExecSessionFrames(ctx, wsConn, session, subscriber)
}()
for {
select {
case readFramesErr := <-readFramesErrCh:
controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr)
if readFramesErr != nil &&
!errors.Is(readFramesErr, errExecSessionDetached) &&
!errors.Is(readFramesErr, errExecSessionClosed) {
controller.logger.Warnf("failed to read and process exec frames from WebSocket: %v",
readFramesErr)
}
return responder.Empty()
case outgoingFrame := <-outgoingFrames:
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
controller.logger.Warnf("failed to write WebSocket frame to the client: %v", err)
case outgoingFrame, ok := <-subscriber.frames:
if !ok {
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
}
return responder.Empty()
}
case sshErr := <-sshErrCh:
if sshErr != nil {
if err := execstream.WriteFrame(ctx, wsConn, &execstream.Frame{
Type: execstream.FrameTypeError,
Error: sshErr.Error(),
}); err != nil {
controller.logger.Warnf("exec: failed to write error frame to WebSocket: %v", err)
}
}
if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil {
controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err)
}
if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil {
controller.logger.Warnf("failed to write exec frame to the client: %v", err)
if readFramesErrCh != nil {
// Read() on a WebSocket should unblock shortly after calling Close()
<-readFramesErrCh
return responder.Empty()
}
return responder.Empty()
case <-time.After(controller.pingInterval):
pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
if err := wsConn.Ping(pingCtx); err != nil {
controller.logger.Warnf("port forwarding: failed to ping the client, "+
controller.logger.Warnf("exec: failed to ping the client, "+
"connection might time out: %v", err)
}
@ -151,10 +283,16 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
}
}
func (controller *Controller) readFrames(
var (
errExecSessionDetached = errors.New("exec session detached")
errExecSessionClosed = errors.New("exec session closed")
)
func (controller *Controller) readExecSessionFrames(
ctx context.Context,
wsConn *websocket.Conn,
stdinHandle io.WriteCloser,
session *execSession,
subscriber *execSessionSubscriber,
) error {
for {
var frame execstream.Frame
@ -163,7 +301,7 @@ func (controller *Controller) readFrames(
if err != nil {
var closeErr websocket.CloseError
if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure {
return nil
return errExecSessionDetached
}
return fmt.Errorf("failed to read next frame from WebSocket: %w", err)
@ -179,26 +317,35 @@ func (controller *Controller) readFrames(
switch frame.Type {
case execstream.FrameTypeStdin:
if stdinHandle == nil {
return fmt.Errorf("failed to handle %q frame: this exec session "+
"has no stdin is enabled or already closed", frame.Type)
if err := session.writeStdin(frame.Data); err != nil {
return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err)
}
case execstream.FrameTypeHistory:
if !session.policy.replayEnabled {
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
}
if len(frame.Data) == 0 {
if err := stdinHandle.Close(); err != nil {
return fmt.Errorf("failed to handle %q frame: failed to close "+
"stdin: %w", frame.Type, err)
}
stdinHandle = nil
continue
session.sendHistory(subscriber, frame.Watermark)
case execstream.FrameTypeAck:
if !session.policy.replayEnabled {
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
}
if _, err := stdinHandle.Write(frame.Data); err != nil {
return fmt.Errorf("failed to handle %q frame: failed to write "+
"to stdin: %w", frame.Type, err)
session.ack(frame.Watermark)
case execstream.FrameTypeDetach:
if !session.policy.replayEnabled {
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
}
return errExecSessionDetached
case execstream.FrameTypeClose:
if !session.policy.replayEnabled {
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
}
session.close()
return errExecSessionClosed
default:
return fmt.Errorf("unexpected frame type received: %q", frame.Type)
}

View File

@ -55,6 +55,7 @@ type Controller struct {
ipRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[string]]
enableSwaggerDocs bool
workerOfflineTimeout time.Duration
execSessionExitTTL time.Duration
experimentalRPCV2 bool
disableDBCompression bool
pingInterval time.Duration
@ -64,6 +65,7 @@ type Controller struct {
sshSigner ssh.Signer
sshNoClientAuth bool
sshServer *sshserver.SSHServer
execSessions *execSessionRegistry
single singleflight.Group
@ -75,7 +77,9 @@ func New(opts ...Option) (*Controller, error) {
connRendezvous: rendezvous.New[rendezvous.ResultWithErrorMessage[net.Conn]](),
ipRendezvous: rendezvous.New[rendezvous.ResultWithErrorMessage[string]](),
workerOfflineTimeout: 3 * time.Minute,
execSessionExitTTL: 10 * time.Minute,
pingInterval: 30 * time.Second,
execSessions: newExecSessionRegistry(),
single: singleflight.Group{},
}
@ -308,6 +312,8 @@ func (controller *Controller) Run(ctx context.Context) error {
go func() {
<-ctx.Done()
controller.execSessions.closeAll()
if err := controller.httpServer.Shutdown(ctx); err != nil {
controller.logger.Errorf("failed to cleanly shutdown the HTTP server: %v", err)
}

View File

@ -0,0 +1,692 @@
package controller
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
"github.com/cirruslabs/orchard/internal/execstream"
)
// execSessionReplayBufferBytes caps the total payload bytes retained for
// replay per reconnectable session; older frames are trimmed once the cap
// is exceeded (see execReplayBuffer.trimToLimit).
const execSessionReplayBufferBytes = 4 * 1024 * 1024

// execSessionPolicy controls how an exec session reacts to client
// disconnects, command exit, and history-related frames.
type execSessionPolicy struct {
	// closeOnDetach terminates the session as soon as its subscriber detaches.
	closeOnDetach bool
	// retainAfterExit keeps the session (and its replay buffer) around for
	// the configured TTL after the command exits.
	retainAfterExit bool
	// replayEnabled buffers output frames with watermarks so clients can
	// request history and acknowledge consumption.
	replayEnabled bool
}

var (
	// legacyExecSessionPolicy mirrors the original /exec behavior: the
	// command dies together with the WebSocket connection.
	legacyExecSessionPolicy = execSessionPolicy{
		closeOnDetach: true,
	}
	// reconnectableExecSessionPolicy keeps the command alive across
	// disconnects and supports history replay and acknowledgements.
	reconnectableExecSessionPolicy = execSessionPolicy{
		retainAfterExit: true,
		replayEnabled:   true,
	}
)
// sshExecRunner abstracts the SSH command executor (sshexec) so that tests
// can substitute a fake implementation.
type sshExecRunner interface {
	Stdin() io.WriteCloser
	Run(ctx context.Context, command string, outgoingFrames chan<- *execstream.Frame) error
	Close() error
}

// execSessionKey uniquely identifies an exec session by VM name plus the
// client-supplied session identifier (empty for legacy one-shot sessions).
type execSessionKey struct {
	vmName    string
	sessionID string
}

// execSessionCreation tracks an in-flight session creation so concurrent
// requests for the same key wait for a single create() call to finish.
type execSessionCreation struct {
	done    chan struct{} // closed once session/err below are populated
	session *execSession
	err     error
}

// execSessionRegistry is the controller-wide index of live exec sessions.
type execSessionRegistry struct {
	mu       sync.Mutex
	sessions map[execSessionKey]*execSession
	creating map[execSessionKey]*execSessionCreation
}
// newExecSessionRegistry returns an empty registry ready for use.
func newExecSessionRegistry() *execSessionRegistry {
	registry := &execSessionRegistry{
		sessions: make(map[execSessionKey]*execSession),
		creating: make(map[execSessionKey]*execSessionCreation),
	}

	return registry
}
// get returns the live session registered under key, if any.
func (registry *execSessionRegistry) get(key execSessionKey) (*execSession, bool) {
	registry.mu.Lock()
	session, ok := registry.sessions[key]
	registry.mu.Unlock()

	return session, ok
}
// getOrCreate returns the session registered under key, or invokes create()
// exactly once to make a new one, de-duplicating concurrent callers
// (singleflight-style).
//
// The boolean result reports whether this call performed the creation.
// Callers that lose the race wait on the winner's in-flight creation and
// share its result, or bail out on ctx cancellation, whichever comes first.
func (registry *execSessionRegistry) getOrCreate(
	ctx context.Context,
	key execSessionKey,
	create func() (*execSession, error),
) (*execSession, bool, error) {
	registry.mu.Lock()

	// Fast path: the session already exists.
	if session, ok := registry.sessions[key]; ok {
		registry.mu.Unlock()

		return session, false, nil
	}

	// Another goroutine is already creating this session: wait for it.
	if creation, ok := registry.creating[key]; ok {
		registry.mu.Unlock()

		select {
		case <-ctx.Done():
			return nil, false, ctx.Err()
		case <-creation.done:
			if creation.err != nil {
				return nil, false, creation.err
			}

			return creation.session, false, nil
		}
	}

	// We are the creator: publish the in-flight marker, then run create()
	// outside the lock since it may block (port-forwarding, SSH handshake).
	creation := &execSessionCreation{done: make(chan struct{})}
	registry.creating[key] = creation
	registry.mu.Unlock()

	session, err := create()

	registry.mu.Lock()
	delete(registry.creating, key)
	if err == nil {
		registry.sessions[key] = session
	}
	creation.session = session
	creation.err = err
	// Closing done publishes session/err to the waiters above.
	close(creation.done)
	registry.mu.Unlock()

	return session, true, err
}
// remove unregisters key, but only while it still maps to expected, so a
// newer session registered under the same key is left untouched.
func (registry *execSessionRegistry) remove(key execSessionKey, expected *execSession) {
	registry.mu.Lock()
	defer registry.mu.Unlock()

	if current, ok := registry.sessions[key]; ok && current == expected {
		delete(registry.sessions, key)
	}
}
// closeAll terminates every registered session; used on controller shutdown.
func (registry *execSessionRegistry) closeAll() {
	// Snapshot under the lock: session.close() re-enters the registry
	// (via remove), so it must run unlocked.
	registry.mu.Lock()
	snapshot := make([]*execSession, 0, len(registry.sessions))
	for _, session := range registry.sessions {
		snapshot = append(snapshot, session)
	}
	registry.mu.Unlock()

	for _, session := range snapshot {
		session.close()
	}
}
// execReplayFrame is a buffered output frame plus its accounted size.
type execReplayFrame struct {
	frame *execstream.Frame
	size  int
}

// execReplayBuffer retains recent output frames so a reconnecting client can
// request history. Frames receive monotonically increasing watermarks; the
// buffer drops frames that have been acknowledged or that exceed the byte
// cap. Not safe for concurrent use on its own — callers synchronize via
// execSession.mu.
type execReplayBuffer struct {
	frames      []execReplayFrame
	bufferBytes int
	// nextWatermark is the watermark assigned to the most recent frame.
	nextWatermark uint64
	// ackedWatermark is the highest watermark acknowledged by a client.
	ackedWatermark uint64
}
// append clones frame, stamps it with the next watermark, stores it in the
// buffer (trimming acknowledged/overflowing frames), and returns the
// watermarked clone for fan-out to subscribers.
func (buffer *execReplayBuffer) append(frame *execstream.Frame) *execstream.Frame {
	stored := cloneExecFrame(frame)

	buffer.nextWatermark++
	stored.Watermark = buffer.nextWatermark

	size := execFrameSize(stored)
	buffer.frames = append(buffer.frames, execReplayFrame{
		frame: stored,
		size:  size,
	})
	buffer.bufferBytes += size

	buffer.trimAcknowledged()
	buffer.trimToLimit()

	return stored
}
// ack records a client acknowledgement and drops frames at or below it.
// Stale (non-increasing) watermarks are ignored.
func (buffer *execReplayBuffer) ack(watermark uint64) {
	if watermark > buffer.ackedWatermark {
		buffer.ackedWatermark = watermark
		buffer.trimAcknowledged()
	}
}
// replayAfter appends every buffered frame strictly newer than watermark to
// frames (in order) and returns the extended slice.
func (buffer *execReplayBuffer) replayAfter(
	watermark uint64,
	frames []*execstream.Frame,
) []*execstream.Frame {
	for index := range buffer.frames {
		if frame := buffer.frames[index].frame; frame.Watermark > watermark {
			frames = append(frames, frame)
		}
	}

	return frames
}
// trimAcknowledged drops buffered frames the client has acknowledged
// (watermark <= ackedWatermark); they will never need to be replayed.
func (buffer *execReplayBuffer) trimAcknowledged() {
	for len(buffer.frames) > 0 && buffer.frames[0].frame.Watermark <= buffer.ackedWatermark {
		buffer.dropOldest()
	}
}

// trimToLimit drops the oldest frames until the buffer fits into
// execSessionReplayBufferBytes again.
func (buffer *execReplayBuffer) trimToLimit() {
	for buffer.bufferBytes > execSessionReplayBufferBytes && len(buffer.frames) > 0 {
		buffer.dropOldest()
	}
}

// dropOldest removes the oldest buffered frame, adjusting the byte
// accounting and zeroing the vacated slot so that re-slicing does not keep
// the trimmed frame reachable through the shared backing array.
func (buffer *execReplayBuffer) dropOldest() {
	buffer.bufferBytes -= buffer.frames[0].size
	buffer.frames[0] = execReplayFrame{}
	buffer.frames = buffer.frames[1:]
}
// execSessionSubscriber is a single attached WebSocket's view of a session:
// a bounded frame queue plus bookkeeping to suppress duplicate replays.
type execSessionSubscriber struct {
	frames    chan *execstream.Frame
	closed    chan struct{} // closed (once) when the subscriber is detached
	closeOnce sync.Once
	// sendMu serializes sends into frames and guards sentWatermark; it also
	// excludes close() from closing frames while a send is in progress.
	sendMu sync.Mutex
	// sentWatermark is the highest output watermark already delivered to
	// this subscriber, used to drop duplicates during history replay.
	sentWatermark uint64
}

// newExecSessionSubscriber returns a subscriber with a buffered frame queue;
// consumers that fall behind and overflow the buffer are dropped rather than
// blocking the session's output pump.
func newExecSessionSubscriber() *execSessionSubscriber {
	return &execSessionSubscriber{
		frames: make(chan *execstream.Frame, 128),
		closed: make(chan struct{}),
	}
}
// enqueue attempts a non-blocking delivery of frame to this subscriber.
// It returns true when the frame was queued (or was already delivered per
// its watermark), and false when the subscriber is closed or its queue is
// full — in which case the caller drops the subscriber.
func (subscriber *execSessionSubscriber) enqueue(frame *execstream.Frame) bool {
	subscriber.sendMu.Lock()
	defer subscriber.sendMu.Unlock()

	if subscriber.alreadySentLocked(frame) {
		return true
	}

	// While sendMu is held, close() cannot close the frames channel, so a
	// single closed check here is sufficient and the send below cannot
	// panic. (The original code repeated the closed case inside the send
	// select, where it was redundant under this locking discipline.)
	select {
	case <-subscriber.closed:
		return false
	default:
	}

	// Non-blocking send: a full queue means the consumer is too slow, and
	// the caller will drop this subscriber instead of stalling the session.
	select {
	case subscriber.frames <- subscriber.markSentLocked(frame):
		return true
	default:
		return false
	}
}
// sendHistory delivers a batch of replayed frames in order, stopping early
// if the subscriber is closed. Despite the name, callers must hold sendMu:
// execSession.sendHistory locks it around the whole batch so live frames
// cannot interleave with the replay.
func (subscriber *execSessionSubscriber) sendHistory(frames []*execstream.Frame) bool {
	for _, frame := range frames {
		if !subscriber.sendLocked(frame) {
			return false
		}
	}

	return true
}

// sendLocked performs a blocking send of a single frame; sendMu must be
// held. Frames already delivered (by watermark) are skipped, and the send
// aborts if the subscriber closes while it is blocked — close() closes the
// closed channel before taking sendMu, so a blocked sender always wakes up.
func (subscriber *execSessionSubscriber) sendLocked(frame *execstream.Frame) bool {
	if subscriber.alreadySentLocked(frame) {
		return true
	}

	select {
	case <-subscriber.closed:
		return false
	default:
	}

	// Blocking send (no default case): history delivery is allowed to wait
	// for the consumer; the closed channel provides the escape hatch.
	select {
	case subscriber.frames <- subscriber.markSentLocked(frame):
		return true
	case <-subscriber.closed:
		return false
	}
}
// alreadySentLocked reports whether frame is a watermarked output frame
// that this subscriber has already received; sendMu must be held.
func (subscriber *execSessionSubscriber) alreadySentLocked(frame *execstream.Frame) bool {
	if !isReplayOutputFrame(frame) || frame.Watermark == 0 {
		return false
	}

	return frame.Watermark <= subscriber.sentWatermark
}

// markSentLocked clones frame and advances the subscriber's high-water
// mark for watermarked output frames; sendMu must be held.
func (subscriber *execSessionSubscriber) markSentLocked(frame *execstream.Frame) *execstream.Frame {
	clone := cloneExecFrame(frame)
	if isReplayOutputFrame(clone) && clone.Watermark > subscriber.sentWatermark {
		subscriber.sentWatermark = clone.Watermark
	}

	return clone
}
// close detaches the subscriber exactly once. The closed channel is closed
// first so any sender blocked in sendLocked bails out and releases sendMu;
// only then is the frames channel closed, under sendMu, so no in-flight
// send can race with (and panic on) the channel close.
func (subscriber *execSessionSubscriber) close() {
	subscriber.closeOnce.Do(func() {
		close(subscriber.closed)

		subscriber.sendMu.Lock()
		close(subscriber.frames)
		subscriber.sendMu.Unlock()
	})
}
// execSession owns a single remote command execution: the SSH transport,
// the command's lifecycle, the set of attached WebSocket subscribers and
// (for reconnectable sessions) the replay buffer.
type execSession struct {
	key       execSessionKey
	command   string
	exec      sshExecRunner
	transport net.Conn // underlying port-forwarding connection, closed with the session
	registry  *execSessionRegistry
	exitTTL   time.Duration // how long to retain the session after the command exits
	policy    execSessionPolicy

	// ctx/cancel bound the remote command's execution.
	ctx    context.Context
	cancel context.CancelFunc

	// mu guards all of the mutable state below.
	mu          sync.Mutex
	stdin       io.WriteCloser
	stdinClosed bool
	subscribers map[*execSessionSubscriber]struct{}
	replay      execReplayBuffer
	started     bool
	finished    bool
	closed      bool
	// expiryTimer fires exitTTL after command exit to reap retained sessions.
	expiryTimer *time.Timer

	startOnce sync.Once
	done      chan struct{} // closed once the command has finished
	doneOnce  sync.Once
}
// newExecSession constructs a session with a fresh background-derived
// context; see newExecSessionWithContext for parameter semantics.
func newExecSession(
	key execSessionKey,
	command string,
	exec sshExecRunner,
	transport net.Conn,
	registry *execSessionRegistry,
	exitTTL time.Duration,
	policy execSessionPolicy,
) *execSession {
	ctx, cancel := context.WithCancel(context.Background())

	return newExecSessionWithContext(
		ctx,
		cancel,
		key,
		command,
		exec,
		transport,
		registry,
		exitTTL,
		policy,
	)
}

// newExecSessionWithContext constructs a session driven by the supplied
// context (or a fresh background-derived one when ctx/cancel are nil).
// Callers pass background-derived contexts so that the command can outlive
// individual WebSocket requests. registry may be nil for sessions that are
// not registered (legacy /exec); transport may also be nil in tests.
func newExecSessionWithContext(
	ctx context.Context,
	cancel context.CancelFunc,
	key execSessionKey,
	command string,
	exec sshExecRunner,
	transport net.Conn,
	registry *execSessionRegistry,
	exitTTL time.Duration,
	policy execSessionPolicy,
) *execSession {
	if ctx == nil || cancel == nil {
		ctx, cancel = context.WithCancel(context.Background())
	}

	session := &execSession{
		key:         key,
		command:     command,
		exec:        exec,
		transport:   transport,
		registry:    registry,
		exitTTL:     exitTTL,
		policy:      policy,
		ctx:         ctx,
		cancel:      cancel,
		stdin:       exec.Stdin(),
		subscribers: map[*execSessionSubscriber]struct{}{},
		done:        make(chan struct{}),
	}

	return session
}
// commandMatches reports whether a reconnect request is compatible with
// this session's command; an empty command (reconnect-only) matches any.
func (session *execSession) commandMatches(command string) bool {
	if command == "" {
		return true
	}

	return session.command == command
}
// start launches the remote command exactly once; subsequent (reconnect)
// calls are no-ops. A session that was closed before its first start never
// runs the command.
func (session *execSession) start() {
	session.startOnce.Do(func() {
		session.mu.Lock()
		if session.closed {
			session.mu.Unlock()

			return
		}
		session.started = true
		session.mu.Unlock()

		go session.run()
	})
}

// closeIfUnused tears the session down if it never started and has no
// subscribers — e.g. when the WebSocket upgrade failed right after the
// session was created.
func (session *execSession) closeIfUnused() {
	session.mu.Lock()
	unused := !session.started && len(session.subscribers) == 0
	session.mu.Unlock()

	if unused {
		session.close()
	}
}
// attach registers a new subscriber; it fails once the session is closed.
func (session *execSession) attach() (*execSessionSubscriber, error) {
	session.mu.Lock()
	defer session.mu.Unlock()

	if session.closed {
		return nil, errors.New("exec session is closed")
	}

	subscriber := newExecSessionSubscriber()
	session.subscribers[subscriber] = struct{}{}

	return subscriber, nil
}

// detach removes a subscriber. Under the legacy policy (closeOnDetach) the
// whole session — including the remote command — dies with its only client.
func (session *execSession) detach(subscriber *execSessionSubscriber) {
	if session.policy.closeOnDetach {
		session.close()

		return
	}

	session.mu.Lock()
	defer session.mu.Unlock()

	session.detachLocked(subscriber)
}

// detachLocked removes and closes a subscriber; session.mu must be held.
// Idempotent: unknown subscribers are ignored.
func (session *execSession) detachLocked(subscriber *execSessionSubscriber) {
	if _, ok := session.subscribers[subscriber]; !ok {
		return
	}

	delete(session.subscribers, subscriber)
	subscriber.close()
}
// writeStdin forwards data to the remote command's standard input. An empty
// payload means "close stdin" (delivers EOF to the command), after which
// further writes are rejected.
func (session *execSession) writeStdin(data []byte) error {
	session.mu.Lock()
	defer session.mu.Unlock()

	if session.stdin == nil || session.stdinClosed {
		return errors.New("this exec session has no stdin enabled or it is already closed")
	}

	if len(data) == 0 {
		if err := session.stdin.Close(); err != nil {
			return err
		}
		session.stdinClosed = true

		return nil
	}

	_, err := session.stdin.Write(data)

	return err
}
// ack records the client's durable-consumption watermark, allowing the
// replay buffer to drop everything at or below it. No-op for sessions
// without replay.
func (session *execSession) ack(watermark uint64) {
	if !session.policy.replayEnabled {
		return
	}

	session.mu.Lock()
	defer session.mu.Unlock()

	session.replay.ack(watermark)
}

// sendHistory replays buffered frames strictly newer than watermark to the
// subscriber, terminated by a no_more_history marker.
//
// Lock order: session.mu is taken first to snapshot the replay buffer, and
// subscriber.sendMu is acquired BEFORE session.mu is released — this keeps
// recordFrame (which enqueues under sendMu after dropping session.mu) from
// interleaving a live frame into the middle of the replayed history.
func (session *execSession) sendHistory(
	subscriber *execSessionSubscriber,
	watermark uint64,
) {
	if !session.policy.replayEnabled {
		return
	}

	session.mu.Lock()
	// Only replay to subscribers that are still attached.
	if _, ok := session.subscribers[subscriber]; !ok {
		session.mu.Unlock()

		return
	}
	subscriber.sendMu.Lock()
	frames := session.replay.replayAfter(watermark, nil)
	frames = append(frames, &execstream.Frame{
		Type:      execstream.FrameTypeNoMoreHistory,
		Watermark: session.replay.nextWatermark,
	})
	session.mu.Unlock()

	ok := subscriber.sendHistory(frames)
	subscriber.sendMu.Unlock()

	// A failed send means the subscriber closed mid-replay: drop it.
	if !ok {
		session.dropSubscriber(subscriber)
	}
}
// close tears the session down idempotently: it stops the expiry timer,
// closes all subscribers, cancels the command's context, closes the SSH
// executor and transport, and unregisters from the registry. Safe to call
// multiple times and from any goroutine.
func (session *execSession) close() {
	session.mu.Lock()
	if session.closed {
		session.mu.Unlock()

		return
	}
	session.closed = true
	if session.expiryTimer != nil {
		session.expiryTimer.Stop()
		session.expiryTimer = nil
	}
	subscribers := session.takeSubscribersLocked()
	session.mu.Unlock()

	// Subscriber close and I/O teardown happen outside session.mu to avoid
	// lock-order issues with senders holding subscriber.sendMu.
	closeSubscribers(subscribers)

	session.cancel()
	_ = session.exec.Close()
	if session.transport != nil {
		_ = session.transport.Close()
	}
	if session.registry != nil {
		session.registry.remove(session.key, session)
	}
}
// run executes the remote command and pumps its output frames into the
// session until the command finishes, then records a terminal error frame
// (if any) and transitions the session to the finished state.
func (session *execSession) run() {
	outgoingFrames := make(chan *execstream.Frame)
	runErrCh := make(chan error, 1)

	go func() {
		runErrCh <- session.exec.Run(session.ctx, session.command, outgoingFrames)
		// Run has returned, so no more frames will be produced; closing the
		// channel ends the pump loop below.
		close(outgoingFrames)
	}()

	for frame := range outgoingFrames {
		session.recordFrame(frame)
	}

	runErr := <-runErrCh
	// context.Canceled means the session was closed deliberately — not an
	// error worth surfacing to clients.
	if runErr != nil && !errors.Is(runErr, context.Canceled) {
		session.recordFrame(&execstream.Frame{
			Type:  execstream.FrameTypeError,
			Error: runErr.Error(),
		})
	}

	session.markFinished()
}
// recordFrame stamps a frame into the replay buffer (when enabled) and fans
// it out to all current subscribers; subscribers whose queue is full are
// dropped so a slow client cannot stall the command's output.
func (session *execSession) recordFrame(frame *execstream.Frame) {
	session.mu.Lock()
	if session.closed {
		session.mu.Unlock()

		return
	}
	if session.policy.replayEnabled {
		// append clones the frame and assigns it the next watermark.
		frame = session.replay.append(frame)
	} else {
		// Clone so subscribers never share mutable state with the producer.
		frame = cloneExecFrame(frame)
	}
	subscribers := make([]*execSessionSubscriber, 0, len(session.subscribers))
	for subscriber := range session.subscribers {
		subscribers = append(subscribers, subscriber)
	}
	session.mu.Unlock()

	// Enqueue outside session.mu: enqueue takes subscriber.sendMu and may
	// contend with an in-progress history replay.
	for _, subscriber := range subscribers {
		if !subscriber.enqueue(frame) {
			session.dropSubscriber(subscriber)
		}
	}
}
// markFinished transitions the session to the finished state exactly once:
// reconnectable sessions (retainAfterExit) are kept around for exitTTL so
// clients can reconnect and fetch remaining history, while legacy sessions
// close immediately.
func (session *execSession) markFinished() {
	session.mu.Lock()
	if session.finished {
		session.mu.Unlock()

		return
	}
	session.finished = true
	shouldClose := !session.policy.retainAfterExit
	if !session.closed && session.policy.retainAfterExit {
		// Reap the retained session once the TTL elapses.
		session.expiryTimer = time.AfterFunc(session.exitTTL, session.expire)
	}
	var subscribers []*execSessionSubscriber
	if shouldClose {
		subscribers = session.takeSubscribersLocked()
	}
	session.mu.Unlock()

	closeSubscribers(subscribers)

	// Signal completion to anyone waiting on done.
	session.doneOnce.Do(func() {
		close(session.done)
	})

	if shouldClose {
		session.close()
	}
}

// expire is the exitTTL timer callback: it closes the retained session.
func (session *execSession) expire() {
	session.close()
}
// takeSubscribersLocked detaches and returns every current subscriber,
// leaving the session with an empty set; session.mu must be held.
func (session *execSession) takeSubscribersLocked() []*execSessionSubscriber {
	taken := make([]*execSessionSubscriber, 0, len(session.subscribers))
	for subscriber := range session.subscribers {
		taken = append(taken, subscriber)
	}
	session.subscribers = make(map[*execSessionSubscriber]struct{})

	return taken
}

// closeSubscribers closes each of the given subscribers.
func closeSubscribers(subscribers []*execSessionSubscriber) {
	for index := range subscribers {
		subscribers[index].close()
	}
}

// dropSubscriber detaches a single subscriber, taking session.mu itself.
func (session *execSession) dropSubscriber(subscriber *execSessionSubscriber) {
	session.mu.Lock()
	session.detachLocked(subscriber)
	session.mu.Unlock()
}
// cloneExecFrame returns a deep copy of frame (nil-safe) so that buffered
// and fanned-out frames never share mutable Data/Exit state with the caller.
func cloneExecFrame(frame *execstream.Frame) *execstream.Frame {
	if frame == nil {
		return nil
	}

	duplicate := *frame

	if frame.Data != nil {
		duplicate.Data = make([]byte, len(frame.Data))
		copy(duplicate.Data, frame.Data)
	}

	if frame.Exit != nil {
		exitCopy := *frame.Exit
		duplicate.Exit = &exitCopy
	}

	return &duplicate
}
// execFrameSize estimates a frame's memory footprint for replay-buffer
// accounting: payload plus error text plus a small fixed overhead.
func execFrameSize(frame *execstream.Frame) int {
	if frame == nil {
		return 0
	}

	const perFrameOverhead = 16

	return len(frame.Data) + len(frame.Error) + perFrameOverhead
}
// isReplayOutputFrame reports whether frame carries process output subject
// to watermarking (stdout/stderr/exit/error), as opposed to control frames.
func isReplayOutputFrame(frame *execstream.Frame) bool {
	if frame == nil {
		return false
	}

	frameType := frame.Type

	return frameType == execstream.FrameTypeStdout ||
		frameType == execstream.FrameTypeStderr ||
		frameType == execstream.FrameTypeExit ||
		frameType == execstream.FrameTypeError
}

View File

@ -0,0 +1,320 @@
package controller
import (
"context"
"io"
"sync/atomic"
"testing"
"time"
"github.com/cirruslabs/orchard/internal/execstream"
"github.com/stretchr/testify/require"
)
type fakeExec struct {
stdin io.WriteCloser
run func(context.Context, string, chan<- *execstream.Frame) error
closeCalls atomic.Int32
}
func (exec *fakeExec) Stdin() io.WriteCloser {
return exec.stdin
}
// Run delegates to the configured callback, or succeeds immediately when no
// callback was provided.
func (exec *fakeExec) Run(
	ctx context.Context,
	command string,
	outgoingFrames chan<- *execstream.Frame,
) error {
	if exec.run == nil {
		return nil
	}
	return exec.run(ctx, command, outgoingFrames)
}
// Close records the invocation so tests can assert whether the session
// stopped the underlying process.
func (exec *fakeExec) Close() error {
	exec.closeCalls.Add(1)
	return nil
}
// newManualExecSessionForTest builds an execSession wired to a fresh fakeExec
// without going through newExecSession/start, so tests can drive its
// internals (attach, recordFrame, policy, TTL) directly.
func newManualExecSessionForTest(
	key execSessionKey,
	registry *execSessionRegistry,
) *execSession {
	ctx, cancel := context.WithCancel(context.Background())
	return &execSession{
		key:         key,
		command:     "echo test",
		exec:        &fakeExec{},
		registry:    registry,
		exitTTL:     time.Minute, // long enough to never fire during a test
		policy:      reconnectableExecSessionPolicy,
		ctx:         ctx,
		cancel:      cancel,
		subscribers: map[*execSessionSubscriber]struct{}{},
		done:        make(chan struct{}),
	}
}
// TestExecSessionRegistryGetOrCreateReusesInflightCreation verifies that a
// second getOrCreate for the same key joins an in-flight creation instead of
// invoking the create callback again.
//
// Fix: the original asserted with testify's require inside the worker
// goroutines; require calls t.FailNow, which must only run on the goroutine
// executing the test. Results are now funneled back over channels and
// asserted on the test goroutine.
func TestExecSessionRegistryGetOrCreateReusesInflightCreation(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	createStarted := make(chan struct{})
	releaseCreate := make(chan struct{})
	var createCalls atomic.Int32
	create := func() (*execSession, error) {
		createCalls.Add(1)
		close(createStarted)
		// Block until the test has issued the concurrent second call.
		<-releaseCreate
		return newManualExecSessionForTest(key, registry), nil
	}
	firstErr := make(chan error, 1)
	go func() {
		_, _, err := registry.getOrCreate(context.Background(), key, create)
		firstErr <- err
	}()
	// Wait until the first creation is definitely in flight.
	<-createStarted
	type getOrCreateResult struct {
		created bool
		err     error
	}
	secondResult := make(chan getOrCreateResult, 1)
	go func() {
		_, created, err := registry.getOrCreate(context.Background(), key, create)
		secondResult <- getOrCreateResult{created: created, err: err}
	}()
	close(releaseCreate)
	require.NoError(t, <-firstErr)
	second := <-secondResult
	require.NoError(t, second.err)
	require.False(t, second.created)
	require.EqualValues(t, 1, createCalls.Load())
}
// TestExecSessionStartRunsCommandOnlyOnce verifies that start() is
// idempotent: calling it twice must launch the underlying command once.
func TestExecSessionStartRunsCommandOnlyOnce(t *testing.T) {
	var runCalls atomic.Int32
	runStarted := make(chan struct{})
	session := newExecSession(
		execSessionKey{vmName: "vm", sessionID: "session"},
		"echo test",
		&fakeExec{
			run: func(ctx context.Context, _ string, _ chan<- *execstream.Frame) error {
				runCalls.Add(1)
				// A second invocation would also panic on the double
				// close, failing the test loudly.
				close(runStarted)
				// Simulate a long-running command until the session
				// context is canceled by close().
				<-ctx.Done()
				return ctx.Err()
			},
		},
		nil,
		nil,
		time.Minute,
		reconnectableExecSessionPolicy,
	)
	defer session.close()
	session.start()
	session.start()
	// Wait until the command is actually running before asserting.
	<-runStarted
	require.EqualValues(t, 1, runCalls.Load())
}
// TestExecSessionHistoryReplayAndAck checks that recorded frames are
// replayed in order, that the replay terminates with a no_more_history
// frame carrying the latest watermark, and that ack() trims acknowledged
// frames from the replay buffer.
func TestExecSessionHistoryReplayAndAck(t *testing.T) {
	registry := newExecSessionRegistry()
	session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
	session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")})
	session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStderr, Data: []byte("err")})
	session.recordFrame(&execstream.Frame{
		Type: execstream.FrameTypeExit,
		Exit: &execstream.Exit{Code: 7},
	})
	subscriber, err := session.attach()
	require.NoError(t, err)
	// Watermark 0 requests the replay from the very beginning.
	session.sendHistory(subscriber, 0)
	require.Equal(t, execstream.FrameTypeStdout, (<-subscriber.frames).Type)
	require.Equal(t, execstream.FrameTypeStderr, (<-subscriber.frames).Type)
	require.Equal(t, execstream.FrameTypeExit, (<-subscriber.frames).Type)
	noMoreHistory := <-subscriber.frames
	require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type)
	require.EqualValues(t, 3, noMoreHistory.Watermark)
	// Acknowledging up to watermark 2 must leave only the third frame buffered.
	session.ack(2)
	require.Len(t, session.replay.frames, 1)
	require.EqualValues(t, 3, session.replay.frames[0].frame.Watermark)
}
// TestExecSessionHistoryReplayStreamsPastSubscriberBuffer verifies that
// sendHistory streams a replay larger than the subscriber's channel buffer
// while the consumer drains, rather than requiring the whole history to be
// enqueued up front.
func TestExecSessionHistoryReplayStreamsPastSubscriberBuffer(t *testing.T) {
	registry := newExecSessionRegistry()
	session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
	// 256 frames — presumably more than the subscriber channel buffer
	// created by attach(); TODO confirm against the buffer size.
	const frameCount = 256
	for i := 0; i < frameCount; i++ {
		session.recordFrame(&execstream.Frame{
			Type: execstream.FrameTypeStdout,
			Data: []byte{byte(i)},
		})
	}
	subscriber, err := session.attach()
	require.NoError(t, err)
	// Run the replay concurrently: it would deadlock here if it tried to
	// enqueue every frame before the test starts draining.
	done := make(chan struct{})
	go func() {
		defer close(done)
		session.sendHistory(subscriber, 0)
	}()
	// Watermarks are 1-based and must arrive strictly in order.
	for i := 1; i <= frameCount; i++ {
		frame := <-subscriber.frames
		require.Equal(t, execstream.FrameTypeStdout, frame.Type)
		require.EqualValues(t, i, frame.Watermark)
	}
	noMoreHistory := <-subscriber.frames
	require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type)
	require.EqualValues(t, frameCount, noMoreHistory.Watermark)
	// sendHistory itself must return promptly once the replay is drained.
	require.Eventually(t, func() bool {
		select {
		case <-done:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond)
}
// TestExecSessionDetachKeepsProcessAlive verifies that, under the
// reconnectable policy, detaching the last subscriber neither marks the
// session closed nor stops the underlying process.
func TestExecSessionDetachKeepsProcessAlive(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	session := newManualExecSessionForTest(key, registry)
	fake := session.exec.(*fakeExec)
	subscriber, err := session.attach()
	require.NoError(t, err)
	session.detach(subscriber)
	require.False(t, session.closed)
	require.EqualValues(t, 0, fake.closeCalls.Load())
}
// TestLegacyExecSessionDetachStopsProcess verifies that, under the legacy
// (non-reconnectable) policy, the session is torn down as soon as its only
// subscriber detaches.
func TestLegacyExecSessionDetachStopsProcess(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	session := newManualExecSessionForTest(key, registry)
	session.policy = legacyExecSessionPolicy
	fake := session.exec.(*fakeExec)
	subscriber, err := session.attach()
	require.NoError(t, err)
	session.detach(subscriber)
	require.True(t, session.closed)
	require.EqualValues(t, 1, fake.closeCalls.Load())
}
// TestLegacyExecSessionDoesNotRetainReplayHistory verifies that legacy
// sessions record no output frames for later replay and never advance the
// watermark.
func TestLegacyExecSessionDoesNotRetainReplayHistory(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	session := newManualExecSessionForTest(key, registry)
	session.policy = legacyExecSessionPolicy
	session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")})
	require.Empty(t, session.replay.frames)
	require.Zero(t, session.replay.nextWatermark)
}
// TestExecSessionCloseIfUnusedClosesIdleSession verifies that a registered
// session with no subscribers is fully closed by closeIfUnused.
func TestExecSessionCloseIfUnusedClosesIdleSession(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	session := newManualExecSessionForTest(key, registry)
	fake := session.exec.(*fakeExec)
	registry.sessions[key] = session
	session.closeIfUnused()
	require.True(t, session.closed)
	require.EqualValues(t, 1, fake.closeCalls.Load())
}
// TestExecSessionCloseIfUnusedKeepsAttachedSession verifies that
// closeIfUnused is a no-op while at least one subscriber is attached.
func TestExecSessionCloseIfUnusedKeepsAttachedSession(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	session := newManualExecSessionForTest(key, registry)
	_, err := session.attach()
	require.NoError(t, err)
	session.closeIfUnused()
	require.False(t, session.closed)
}
// TestExecSessionCloseStopsProcessAndRemovesRegistryEntry verifies that
// close() stops the underlying process exactly once and unregisters the
// session from the registry.
func TestExecSessionCloseStopsProcessAndRemovesRegistryEntry(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	session := newManualExecSessionForTest(key, registry)
	fake := session.exec.(*fakeExec)
	registry.sessions[key] = session
	session.close()
	require.True(t, session.closed)
	require.EqualValues(t, 1, fake.closeCalls.Load())
	_, stillRegistered := registry.get(key)
	require.False(t, stillRegistered)
}
// TestExecSessionFinishedEntryExpiresAfterTTL verifies that a finished
// session is evicted from the registry once its exit TTL elapses.
func TestExecSessionFinishedEntryExpiresAfterTTL(t *testing.T) {
	registry := newExecSessionRegistry()
	key := execSessionKey{vmName: "vm", sessionID: "session"}
	session := newManualExecSessionForTest(key, registry)
	// Shrink the TTL so the expiry timer fires within the test's budget.
	session.exitTTL = 10 * time.Millisecond
	registry.sessions[key] = session
	session.markFinished()
	require.Eventually(t, func() bool {
		_, ok := registry.get(key)
		return !ok
	}, time.Second, 10*time.Millisecond)
}
// TestExecSessionFinishKeepsReconnectableSubscriberOpen verifies that
// markFinished does not close attached subscribers under the reconnectable
// policy: live frames are still delivered and the subscriber can
// subsequently request a full history replay.
func TestExecSessionFinishKeepsReconnectableSubscriberOpen(t *testing.T) {
	registry := newExecSessionRegistry()
	session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry)
	subscriber, err := session.attach()
	require.NoError(t, err)
	session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")})
	session.recordFrame(&execstream.Frame{
		Type: execstream.FrameTypeExit,
		Exit: &execstream.Exit{Code: 0},
	})
	session.markFinished()
	// Live delivery continues despite the session being finished.
	require.Equal(t, execstream.FrameTypeStdout, (<-subscriber.frames).Type)
	require.Equal(t, execstream.FrameTypeExit, (<-subscriber.frames).Type)
	// The subscriber channel must still be open: a replay from the start
	// terminates with no_more_history at the last recorded watermark.
	session.sendHistory(subscriber, 0)
	noMoreHistory, ok := <-subscriber.frames
	require.True(t, ok)
	require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type)
	require.EqualValues(t, 2, noMoreHistory.Watermark)
}

View File

@ -60,6 +60,12 @@ func WithWorkerOfflineTimeout(workerOfflineTimeout time.Duration) Option {
}
}
// WithExecSessionExitTTL overrides how long the controller retains a
// finished exec session (for history replay and reconnects) after its
// command exits.
func WithExecSessionExitTTL(execSessionExitTTL time.Duration) Option {
	return func(controller *Controller) {
		controller.execSessionExitTTL = execSessionExitTTL
	}
}
func WithExperimentalRPCV2() Option {
return func(controller *Controller) {
controller.experimentalRPCV2 = true

View File

@ -14,11 +14,12 @@ import (
)
type Exec struct {
sshClient *ssh.Client
sshSession *ssh.Session
stdout io.Reader
stderr io.Reader
stdin io.WriteCloser
sshClient *ssh.Client
sshSession *ssh.Session
stdout io.Reader
stderr io.Reader
stdin io.WriteCloser
stdinReader *io.PipeReader
}
func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, error) {
@ -52,14 +53,10 @@ func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, err
}
if stdin {
exec.stdin, err = sshSession.StdinPipe()
if err != nil {
_ = sshSession.Close()
_ = sshClient.Close()
return nil, fmt.Errorf("failed to create standard input pipe "+
"for the SSH session: %w", err)
}
stdinReader, stdinWriter := io.Pipe()
sshSession.Stdin = stdinReader
exec.stdinReader = stdinReader
exec.stdin = stdinWriter
}
exec.stdout, err = sshSession.StdoutPipe()
@ -92,6 +89,12 @@ func (exec *Exec) Run(
command string,
outgoingFrames chan<- *execstream.Frame,
) error {
if exec.stdinReader != nil {
defer func() {
_ = exec.stdinReader.Close()
}()
}
if err := exec.sshSession.Start(command); err != nil {
return fmt.Errorf("failed to start command %q: %w", command, err)
}
@ -188,6 +191,13 @@ func ioStreamReader(
}
func (exec *Exec) Close() error {
if exec.stdin != nil {
_ = exec.stdin.Close()
}
if exec.stdinReader != nil {
_ = exec.stdinReader.Close()
}
if err := exec.sshSession.Close(); err != nil {
_ = exec.sshClient.Close()

View File

@ -10,18 +10,24 @@ import (
type FrameType string
const (
FrameTypeStdin FrameType = "stdin"
FrameTypeStdout FrameType = "stdout"
FrameTypeStderr FrameType = "stderr"
FrameTypeExit FrameType = "exit"
FrameTypeError FrameType = "error"
FrameTypeStdin FrameType = "stdin"
FrameTypeStdout FrameType = "stdout"
FrameTypeStderr FrameType = "stderr"
FrameTypeExit FrameType = "exit"
FrameTypeError FrameType = "error"
FrameTypeHistory FrameType = "history"
FrameTypeNoMoreHistory FrameType = "no_more_history"
FrameTypeAck FrameType = "ack"
FrameTypeDetach FrameType = "detach"
FrameTypeClose FrameType = "close"
)
type Frame struct {
Type FrameType `json:"type"`
Data []byte `json:"data,omitempty"`
Exit *Exit `json:"exit,omitempty"`
Error string `json:"error,omitempty"`
Type FrameType `json:"type"`
Data []byte `json:"data,omitempty"`
Exit *Exit `json:"exit,omitempty"`
Error string `json:"error,omitempty"`
Watermark uint64 `json:"watermark,omitempty"`
}
type Exit struct {

View File

@ -0,0 +1,23 @@
package execstream
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
// TestFrameRoundTripsWatermark verifies that the Watermark field survives a
// JSON marshal/unmarshal round trip.
func TestFrameRoundTripsWatermark(t *testing.T) {
	original := Frame{
		Type:      FrameTypeHistory,
		Watermark: 42,
	}
	payload, err := json.Marshal(original)
	require.NoError(t, err)
	var roundTripped Frame
	require.NoError(t, json.Unmarshal(payload, &roundTripped))
	require.Equal(t, original, roundTripped)
}

View File

@ -2,6 +2,7 @@ package tests_test
import (
"bytes"
"context"
"encoding/json"
"testing"
"time"
@ -132,6 +133,176 @@ func TestVMExecScript(t *testing.T) {
require.Equal(t, websocket.StatusNormalClosure, closeError.Code)
}
// TestVMExecSessionReconnectHistory verifies the partial-replay path:
// after consuming one frame, detaching, and reconnecting with the same
// session id, a history request carrying the last seen watermark must
// deliver only the output produced after that watermark.
func TestVMExecSessionReconnectHistory(t *testing.T) {
	devClient, vmName := prepareForExec(t)
	sessionID := uuid.NewString()
	wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		Command:     "sh -c 'echo first; sleep 1; echo second'",
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	firstFrame := readFrame(t, wsConn)
	require.Equal(t, execstream.FrameTypeStdout, firstFrame.Type)
	require.Equal(t, "first\n", string(firstFrame.Data))
	require.EqualValues(t, 1, firstFrame.Watermark)
	// Detach explicitly so the controller keeps the command running.
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
	require.NoError(t, err)
	_ = wsConn.CloseNow()
	// Let the detached process finish so this test verifies partial replay
	// without relying on live-output timing.
	time.Sleep(2 * time.Second)
	// Reconnect without a command: the session id identifies the session.
	wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	defer wsConn.CloseNow()
	// Request history from the last watermark we saw before detaching.
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
		Type:      execstream.FrameTypeHistory,
		Watermark: firstFrame.Watermark,
	})
	require.NoError(t, err)
	frames := readFramesUntilExit(t, wsConn)
	// Only the post-watermark output ("second") must be replayed.
	require.Len(t, framesByType(frames, execstream.FrameTypeStdout), 1)
	require.Equal(t, "second\n", string(framesByType(frames, execstream.FrameTypeStdout)[0].Data))
	require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code)
}
// TestVMExecSessionReconnectAfterExit verifies that output from a session
// whose command already exited can still be replayed in full by
// reconnecting with the same session id.
func TestVMExecSessionReconnectAfterExit(t *testing.T) {
	devClient, vmName := prepareForExec(t)
	sessionID := uuid.NewString()
	wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		Command:     "sh -c 'echo replay-me'",
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
	require.NoError(t, err)
	_ = wsConn.CloseNow()
	// Give the short-lived command time to exit before reconnecting.
	time.Sleep(time.Second)
	wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	defer wsConn.CloseNow()
	// A history frame without a watermark requests the replay from the start.
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory})
	require.NoError(t, err)
	frames := readFramesUntilExit(t, wsConn)
	require.Equal(t, "replay-me\n", string(framesByType(frames, execstream.FrameTypeStdout)[0].Data))
	require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code)
}
// TestVMExecSessionReplayPreservesStreams verifies that a full history
// replay keeps stdout and stderr frames distinct and in their original
// interleaving order.
func TestVMExecSessionReplayPreservesStreams(t *testing.T) {
	devClient, vmName := prepareForExec(t)
	sessionID := uuid.NewString()
	wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		// The sleeps force the four writes into separate frames.
		Command:     "sh -c 'echo out1; sleep 1; echo err1 >&2; sleep 1; echo out2; sleep 1; echo err2 >&2'",
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
	require.NoError(t, err)
	_ = wsConn.CloseNow()
	// Wait out the ~3s of sleeps so the command has finished before replay.
	time.Sleep(4 * time.Second)
	wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	defer wsConn.CloseNow()
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory})
	require.NoError(t, err)
	frames := readFramesUntilExit(t, wsConn)
	// Stream kinds must alternate exactly as the command produced them.
	require.Equal(t, []execstream.FrameType{
		execstream.FrameTypeStdout,
		execstream.FrameTypeStderr,
		execstream.FrameTypeStdout,
		execstream.FrameTypeStderr,
		execstream.FrameTypeExit,
	}, frameTypes(frames))
	require.Equal(t, "out1\n", string(frames[0].Data))
	require.Equal(t, "err1\n", string(frames[1].Data))
	require.Equal(t, "out2\n", string(frames[2].Data))
	require.Equal(t, "err2\n", string(frames[3].Data))
}
// TestVMExecSessionStdinSurvivesReconnect verifies that a reconnecting
// client can keep writing to the command's stdin: `cat` echoes input sent
// both before and after the detach/reconnect cycle.
func TestVMExecSessionStdinSurvivesReconnect(t *testing.T) {
	devClient, vmName := prepareForExec(t)
	sessionID := uuid.NewString()
	wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		Command:     "/bin/cat",
		Stdin:       true,
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
		Type: execstream.FrameTypeStdin,
		Data: []byte("one\n"),
	})
	require.NoError(t, err)
	frame := readFrame(t, wsConn)
	require.Equal(t, execstream.FrameTypeStdout, frame.Type)
	require.Equal(t, "one\n", string(frame.Data))
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach})
	require.NoError(t, err)
	_ = wsConn.CloseNow()
	wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{
		WaitSeconds: 30,
		Session:     sessionID,
	})
	require.NoError(t, err)
	defer wsConn.CloseNow()
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
		Type: execstream.FrameTypeStdin,
		Data: []byte("two\n"),
	})
	require.NoError(t, err)
	// Empty stdin payload — presumably signals EOF so that cat exits;
	// TODO confirm against the controller's stdin frame handling.
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{
		Type: execstream.FrameTypeStdin,
		Data: []byte{},
	})
	require.NoError(t, err)
	err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory})
	require.NoError(t, err)
	frames := readFramesUntilExit(t, wsConn)
	// The replay must contain the echo of both writes, across the reconnect.
	stdoutFrames := framesByType(frames, execstream.FrameTypeStdout)
	require.Len(t, stdoutFrames, 2)
	require.Equal(t, "one\n", string(stdoutFrames[0].Data))
	require.Equal(t, "two\n", string(stdoutFrames[1].Data))
	require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code)
}
func prepareForExec(t *testing.T) (*client.Client, string) {
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
@ -153,14 +324,62 @@ func prepareForExec(t *testing.T) (*client.Client, string) {
}
func readFrame(t *testing.T, wsConn *websocket.Conn) *execstream.Frame {
t.Helper()
var frame execstream.Frame
messageType, payloadBytes, err := wsConn.Read(t.Context())
readCtx, readCtxCancel := context.WithTimeout(t.Context(), 30*time.Second)
defer readCtxCancel()
messageType, payloadBytes, err := wsConn.Read(readCtx)
require.NoError(t, err)
require.Equal(t, websocket.MessageText, messageType)
err = json.Unmarshal(payloadBytes, &frame)
require.NoError(t, err)
if frame.Type == execstream.FrameTypeError {
require.FailNowf(t, "exec stream error", "%s", frame.Error)
}
return &frame
}
// readFramesUntilExit reads frames off the websocket until an exit frame
// arrives, returning everything collected (including the exit frame) while
// skipping no_more_history bookkeeping frames.
func readFramesUntilExit(t *testing.T, wsConn *websocket.Conn) []*execstream.Frame {
	t.Helper()
	var collected []*execstream.Frame
	for {
		frame := readFrame(t, wsConn)
		switch frame.Type {
		case execstream.FrameTypeNoMoreHistory:
			// Replay bookkeeping, not command output — skip it.
		case execstream.FrameTypeExit:
			return append(collected, frame)
		default:
			collected = append(collected, frame)
		}
	}
}
// framesByType returns the frames of the given type, preserving order.
func framesByType(frames []*execstream.Frame, frameType execstream.FrameType) []*execstream.Frame {
	var matching []*execstream.Frame
	for _, candidate := range frames {
		if candidate.Type != frameType {
			continue
		}
		matching = append(matching, candidate)
	}
	return matching
}
// frameTypes projects the frames onto their Type fields, preserving order.
func frameTypes(frames []*execstream.Frame) []execstream.FrameType {
	var types []execstream.FrameType
	for _, frame := range frames {
		types = append(types, frame.Type)
	}
	return types
}

View File

@ -47,6 +47,13 @@ type EventsPageOptions struct {
Cursor string
}
// ExecSessionOptions configures VMsService.ExecSession.
type ExecSessionOptions struct {
	Command     string // command to run; may be empty when reconnecting to an existing session
	Stdin       bool   // request a stdin stream for the command
	WaitSeconds uint16 // forwarded as the "wait" query parameter
	Session     string // optional stable session id enabling detach/reconnect
}
func (service *VMsService) Create(ctx context.Context, vm *v1.VM) error {
err := service.client.request(ctx, http.MethodPost, "vms",
vm, nil, nil)
@ -164,12 +171,33 @@ func (service *VMsService) Exec(
stdin bool,
waitSeconds uint16,
) (*websocket.Conn, error) {
return service.ExecSession(ctx, name, ExecSessionOptions{
Command: command,
Stdin: stdin,
WaitSeconds: waitSeconds,
})
}
func (service *VMsService) ExecSession(
ctx context.Context,
name string,
options ExecSessionOptions,
) (*websocket.Conn, error) {
params := map[string]string{
"wait": strconv.FormatUint(uint64(options.WaitSeconds), 10),
}
if options.Command != "" {
params["command"] = options.Command
}
if options.Stdin {
params["stdin"] = strconv.FormatBool(true)
}
if options.Session != "" {
params["session"] = options.Session
}
return service.client.wsRequestRaw(ctx, fmt.Sprintf("vms/%s/exec", url.PathEscape(name)),
map[string]string{
"command": command,
"stdin": strconv.FormatBool(stdin),
"wait": strconv.FormatUint(uint64(waitSeconds), 10),
})
params)
}
func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint16) (string, error) {

40
pkg/client/vms_test.go Normal file
View File

@ -0,0 +1,40 @@
package client
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
// TestExecSessionBuildsReconnectableQuery verifies that ExecSession forwards
// the command, stdin, wait and session options as query parameters on the
// websocket upgrade request.
//
// Fix: the original asserted with testify's require inside the httptest
// handler (which runs on the server's goroutine — require.FailNow must only
// be called from the test goroutine) and shared the captured query via an
// unsynchronized variable. The query is now handed back over a buffered
// channel and all assertions run on the test goroutine.
func TestExecSessionBuildsReconnectableQuery(t *testing.T) {
	queryCh := make(chan map[string][]string, 1)
	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		// url.Values is defined as map[string][]string, so the conversion
		// avoids importing net/url here.
		queryCh <- map[string][]string(request.URL.Query())
		conn, err := websocket.Accept(writer, request, nil)
		if err != nil {
			// Accept already wrote an error response; a failed upgrade is
			// observed by the client side of the test below.
			return
		}
		defer conn.CloseNow()
	}))
	defer server.Close()
	devClient, err := New(WithAddress(server.URL))
	require.NoError(t, err)
	conn, err := devClient.VMs().ExecSession(t.Context(), "vm", ExecSessionOptions{
		Command:     "echo hello",
		Stdin:       true,
		WaitSeconds: 7,
		Session:     "resume-me",
	})
	require.NoError(t, err)
	defer conn.CloseNow()
	query := <-queryCh
	require.Equal(t, []string{"echo hello"}, query["command"])
	require.Equal(t, []string{"true"}, query["stdin"])
	require.Equal(t, []string{"7"}, query["wait"])
	require.Equal(t, []string{"resume-me"}, query["session"])
}