From e6a3314f58a846ff1baceaff4e62721affb318ee Mon Sep 17 00:00:00 2001 From: Fedor Kororkov Date: Mon, 4 May 2026 15:49:19 -0400 Subject: [PATCH] Supporting reconnecting to `/exec` socket (#434) --- api/openapi.yaml | 121 +++- internal/command/controller/run.go | 4 + internal/controller/api_vms_exec.go | 283 ++++++--- internal/controller/controller.go | 6 + internal/controller/exec_sessions.go | 692 ++++++++++++++++++++++ internal/controller/exec_sessions_test.go | 320 ++++++++++ internal/controller/option.go | 6 + internal/controller/sshexec/sshexec.go | 36 +- internal/execstream/frame.go | 24 +- internal/execstream/frame_test.go | 23 + internal/tests/exec_test.go | 221 ++++++- pkg/client/vms.go | 38 +- pkg/client/vms_test.go | 40 ++ 13 files changed, 1715 insertions(+), 99 deletions(-) create mode 100644 internal/controller/exec_sessions.go create mode 100644 internal/controller/exec_sessions_test.go create mode 100644 internal/execstream/frame_test.go create mode 100644 pkg/client/vms_test.go diff --git a/api/openapi.yaml b/api/openapi.yaml index 29823a2..10f4d89 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -429,11 +429,33 @@ paths: parameters: - in: query name: command - description: Command to execute + description: | + Command to execute. + + Required when starting a new exec session. May be omitted when reconnecting to an + existing session identified by the `session` parameter. schema: type: string minLength: 1 - required: true + required: false + - in: query + name: session + description: | + Optional stable exec session identifier. When present, websocket disconnects detach + from the command instead of terminating it, and later requests with the same VM name + and session id may reconnect and request buffered history. + schema: + type: string + minLength: 1 + required: false + - in: query + name: cmux_session_id + description: | + Compatibility alias for `session`. Prefer `session` for new Orchard clients. + schema: + type: string + minLength: 1 + required: false - in: query name: stdin description: | @@ -483,7 +505,9 @@ paths: '400': description: Invalid parameters were supplied '404': - description: VM resource with the given name doesn't exist + description: VM resource with the given name or reconnectable exec session doesn't exist + '409': + description: Reconnectable exec session already exists with a different command '503': description: Controller failed to establish a connection with the VM /vms/{name}/ip: @@ -800,10 +824,18 @@ components: description: WebSocket frame from Orchard Client to the Orchard Controller oneOf: - $ref: '#/components/schemas/ExecClientFrameStdin' + - $ref: '#/components/schemas/ExecClientFrameHistory' + - $ref: '#/components/schemas/ExecClientFrameAck' + - $ref: '#/components/schemas/ExecClientFrameDetach' + - $ref: '#/components/schemas/ExecClientFrameClose' discriminator: propertyName: type mapping: stdin: '#/components/schemas/ExecClientFrameStdin' + history: '#/components/schemas/ExecClientFrameHistory' + ack: '#/components/schemas/ExecClientFrameAck' + detach: '#/components/schemas/ExecClientFrameDetach' + close: '#/components/schemas/ExecClientFrameClose' ExecClientFrameStdin: description: Send bytes to the process standard input type: object @@ -822,6 +854,56 @@ components: example: type: stdin data: aGVsbG8K + ExecClientFrameHistory: + description: Request buffered output strictly newer than the supplied watermark + type: object + required: [ type, watermark ] + properties: + type: + type: string + enum: [ history ] + watermark: + type: integer + format: int64 + minimum: 0 + example: + type: history + watermark: 42 + ExecClientFrameAck: + description: Acknowledge that output has been durably consumed through the supplied watermark + type: object + required: [ type, watermark ] + properties: + type: + type: string + enum: [ ack ] + watermark: + type: integer + format: int64 + minimum: 0 + example: + type: ack + watermark: 42 + ExecClientFrameDetach: + description: Detach this websocket while leaving the remote command running + type: object + required: [ type ] + properties: + type: + type: string + enum: [ detach ] + example: + type: detach + ExecClientFrameClose: + description: Close the reconnectable exec session and terminate the remote command + type: object + required: [ type ] + properties: + type: + type: string + enum: [ close ] + example: + type: close ExecControllerFrame: description: WebSocket frame from Orchard Controller to the Orchard Client oneOf: @@ -829,6 +911,7 @@ components: - $ref: '#/components/schemas/ExecControllerFrameStderr' - $ref: '#/components/schemas/ExecControllerFrameExit' - $ref: '#/components/schemas/ExecControllerFrameError' + - $ref: '#/components/schemas/ExecControllerFrameNoMoreHistory' discriminator: propertyName: type mapping: @@ -836,6 +919,7 @@ components: stderr: '#/components/schemas/ExecControllerFrameStderr' exit: '#/components/schemas/ExecControllerFrameExit' error: '#/components/schemas/ExecControllerFrameError' + no_more_history: '#/components/schemas/ExecControllerFrameNoMoreHistory' ExecControllerFrameStdout: description: Standard output from the process type: object @@ -848,6 +932,10 @@ components: type: string format: byte description: Base64-encoded standard output bytes from the process + watermark: + type: integer + format: int64 + description: Monotonic output watermark present on reconnectable sessions example: type: stdout data: aGVsbG8K @@ -863,6 +951,10 @@ components: type: string format: byte description: Base64-encoded standard error bytes from the process + watermark: + type: integer + format: int64 + description: Monotonic output watermark present on reconnectable sessions example: type: stderr data: aGVsbG8K @@ -882,6 +974,10 @@ components: type: integer format: int32 description: Process exit code + watermark: + type: integer + format: int64 + description: Monotonic output watermark present on reconnectable sessions example: type: exit exit: @@ -897,9 +993,28 @@ components: error: type: string description: Error message text + watermark: + type: integer + format: int64 + description: Monotonic output watermark present on reconnectable sessions example: type: error error: Failed to establish SSH connection to a VM + ExecControllerFrameNoMoreHistory: + description: Marker indicating that the requested replay range has been fully sent + type: object + required: [ type, watermark ] + properties: + type: + type: string + enum: [ no_more_history ] + watermark: + type: integer + format: int64 + description: Highest watermark known to the exec session at the time of replay + example: + type: no_more_history + watermark: 42 Event: title: Generic Resource Event type: object diff --git a/internal/command/controller/run.go b/internal/command/controller/run.go index 5406654..8833f93 100644 --- a/internal/command/controller/run.go +++ b/internal/command/controller/run.go @@ -35,6 +35,7 @@ var noExperimentalRPCV2 bool var experimentalPingInterval time.Duration var experimentalDisableDBCompression bool var workerOfflineTimeout time.Duration +var execSessionExitTTL time.Duration var synthetic bool func newRunCommand() *cobra.Command { @@ -87,6 +88,8 @@ func newRunCommand() *cobra.Command { "duration (e.g. 60s or 5m30s) after which a worker is considered offline for the purposes "+ "of scheduling (no new VMs will be scheduled on such worker and already assigned VMs will be "+ "marked as failed)") + cmd.Flags().DurationVar(&execSessionExitTTL, "exec-session-exit-ttl", 10*time.Minute, + "duration to retain reconnectable exec session history after the command exits") // Hidden flags cmd.Flags().BoolVar(&synthetic, "synthetic", false, "") @@ -147,6 +150,7 @@ func runController(cmd *cobra.Command, args []string) (err error) { controller.WithListenAddr(address), controller.WithDataDir(dataDir), controller.WithWorkerOfflineTimeout(workerOfflineTimeout), + controller.WithExecSessionExitTTL(execSessionExitTTL), controller.WithLogger(logger), } diff --git a/internal/controller/api_vms_exec.go b/internal/controller/api_vms_exec.go index de384ca..dafb25c 100644 --- a/internal/controller/api_vms_exec.go +++ b/internal/controller/api_vms_exec.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net" "net/http" "strconv" @@ -28,9 +27,13 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { // Retrieve and parse path and query parameters name := ctx.Param("name") + sessionID := ctx.Query("session") + if sessionID == "" { + sessionID = ctx.Query("cmux_session_id") + } command := ctx.Query("command") - if command == "" { + if sessionID == "" && command == "" { return responder.JSON(http.StatusBadRequest, NewErrorResponse("\"command\" parameter cannot be empty")) } @@ -43,6 +46,20 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { return responder.Code(http.StatusBadRequest) } + if sessionID != "" { + return controller.execVMReconnectable(ctx, name, sessionID, command, stdin, wait) + } + + return controller.execVMLegacy(ctx, name, command, stdin, wait) +} + +func (controller *Controller) execVMLegacy( + ctx *gin.Context, + name string, + command string, + stdin bool, + wait uint64, +) responder.Responder { // Look-up the VM waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second) defer waitContextCancel() @@ -52,33 +69,27 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { return responderImpl } - // Establish a port-forwarding connection to a VM's SSH port - portForwardConn, err := retry.NewWithData[net.Conn]( - retry.Context(waitContext), - retry.DelayType(retry.FixedDelay), - retry.Delay(time.Second), - retry.Attempts(0), - retry.LastErrorOnly(true), - ).Do(func() (net.Conn, error) { - return controller.portForwardConnection(ctx, waitContext, vm.Worker, vm.UID, 22) - }) + session, err := controller.newSSHExecSession( + ctx, + waitContext, + vm, + execSessionKey{vmName: name}, + command, + stdin, + nil, + legacyExecSessionPolicy, + ) if err != nil { return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err)) } - defer portForwardConn.Close() - - // Establish an SSH connection to a VM - exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin) - if err != nil { - return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err)) - } - defer exec.Close() // Upgrade HTTP request to a WebSocket connection wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, }) if err != nil { + session.closeIfUnused() + return responder.Error(err) } defer func() { @@ -89,56 +100,177 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { _ = wsConn.CloseNow() }() - // Read WebSocket frames - readFramesErrCh := make(chan error, 1) - go func() { - readFramesErrCh <- controller.readFrames(ctx, wsConn, exec.Stdin()) + return controller.serveExecSession(ctx, wsConn, session) +} + +func (controller *Controller) execVMReconnectable( + ctx *gin.Context, + name string, + sessionID string, + command string, + stdin bool, + wait uint64, +) responder.Responder { + key := execSessionKey{ + vmName: name, + sessionID: sessionID, + } + + session, ok := controller.execSessions.get(key) + if ok { + if !session.commandMatches(command) { + return responder.JSON(http.StatusConflict, + NewErrorResponse("exec session %q is already running a different command", sessionID)) + } + } else { + if command == "" { + return responder.JSON(http.StatusNotFound, + NewErrorResponse("exec session %q does not exist", sessionID)) + } + + waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second) + defer waitContextCancel() + + vm, responderImpl := controller.waitForVM(waitContext, name) + if responderImpl != nil { + return responderImpl + } + + var err error + session, _, err = controller.execSessions.getOrCreate(waitContext, key, func() (*execSession, error) { + return controller.newSSHExecSession( + ctx, + waitContext, + vm, + key, + command, + stdin, + controller.execSessions, + reconnectableExecSessionPolicy, + ) + }) + if err != nil { + return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err)) + } + + if !session.commandMatches(command) { + return responder.JSON(http.StatusConflict, + NewErrorResponse("exec session %q is already running a different command", sessionID)) + } + } + + wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + session.closeIfUnused() + + return responder.Error(err) + } + defer func() { + _ = wsConn.CloseNow() }() - // Run the command - sshErrCh := make(chan error, 1) - outgoingFrames := make(chan *execstream.Frame) + return controller.serveExecSession(ctx, wsConn, session) +} + +func (controller *Controller) newSSHExecSession( + _ *gin.Context, + waitContext context.Context, + vm *v1.VM, + key execSessionKey, + command string, + stdin bool, + registry *execSessionRegistry, + policy execSessionPolicy, +) (*execSession, error) { + sessionContext, sessionContextCancel := context.WithCancel(context.Background()) + + portForwardConn, err := retry.NewWithData[net.Conn]( + retry.Context(waitContext), + retry.DelayType(retry.FixedDelay), + retry.Delay(time.Second), + retry.Attempts(0), + retry.LastErrorOnly(true), + ).Do(func() (net.Conn, error) { + return controller.portForwardConnection(sessionContext, waitContext, vm.Worker, vm.UID, 22) + }) + if err != nil { + sessionContextCancel() + + return nil, err + } + + exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin) + if err != nil { + sessionContextCancel() + _ = portForwardConn.Close() + + return nil, fmt.Errorf("failed to establish SSH connection to a VM: %w", err) + } + + return newExecSessionWithContext( + sessionContext, + sessionContextCancel, + key, + command, + exec, + portForwardConn, + registry, + controller.execSessionExitTTL, + policy, + ), nil +} + +func (controller *Controller) serveExecSession( + ctx *gin.Context, + wsConn *websocket.Conn, + session *execSession, +) responder.Responder { + subscriber, err := session.attach() + if err != nil { + _ = wsConn.Close(websocket.StatusNormalClosure, err.Error()) + + return responder.Empty() + } + defer session.detach(subscriber) + session.start() + + readFramesErrCh := make(chan error, 1) go func() { - sshErrCh <- exec.Run(ctx, command, outgoingFrames) + readFramesErrCh <- controller.readExecSessionFrames(ctx, wsConn, session, subscriber) }() for { select { case readFramesErr := <-readFramesErrCh: - controller.logger.Warnf("failed to read and process frames from WebSocket: %v", readFramesErr) + if readFramesErr != nil && + !errors.Is(readFramesErr, errExecSessionDetached) && + !errors.Is(readFramesErr, errExecSessionClosed) { + controller.logger.Warnf("failed to read and process exec frames from WebSocket: %v", + readFramesErr) + } return responder.Empty() - case outgoingFrame := <-outgoingFrames: - if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil { - controller.logger.Warnf("failed to write WebSocket frame to the client: %v", err) + case outgoingFrame, ok := <-subscriber.frames: + if !ok { + if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil { + controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err) + } return responder.Empty() } - case sshErr := <-sshErrCh: - if sshErr != nil { - if err := execstream.WriteFrame(ctx, wsConn, &execstream.Frame{ - Type: execstream.FrameTypeError, - Error: sshErr.Error(), - }); err != nil { - controller.logger.Warnf("exec: failed to write error frame to WebSocket: %v", err) - } - } - if err := wsConn.Close(websocket.StatusNormalClosure, "Command finished"); err != nil { - controller.logger.Warnf("exec: failed to close WebSocket cleanly: %v", err) - } + if err := execstream.WriteFrame(ctx, wsConn, outgoingFrame); err != nil { + controller.logger.Warnf("failed to write exec frame to the client: %v", err) - if readFramesErrCh != nil { - // Read() on a WebSocket should unblock shortly after calling Close() - <-readFramesErrCh + return responder.Empty() } - - return responder.Empty() case <-time.After(controller.pingInterval): pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second) if err := wsConn.Ping(pingCtx); err != nil { - controller.logger.Warnf("port forwarding: failed to ping the client, "+ + controller.logger.Warnf("exec: failed to ping the client, "+ "connection might time out: %v", err) } @@ -151,10 +283,16 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { } } -func (controller *Controller) readFrames( +var ( + errExecSessionDetached = errors.New("exec session detached") + errExecSessionClosed = errors.New("exec session closed") +) + +func (controller *Controller) readExecSessionFrames( ctx context.Context, wsConn *websocket.Conn, - stdinHandle io.WriteCloser, + session *execSession, + subscriber *execSessionSubscriber, ) error { for { var frame execstream.Frame @@ -163,7 +301,7 @@ func (controller *Controller) readFrames( if err != nil { var closeErr websocket.CloseError if errors.As(err, &closeErr) && closeErr.Code == websocket.StatusNormalClosure { - return nil + return errExecSessionDetached } return fmt.Errorf("failed to read next frame from WebSocket: %w", err) @@ -179,26 +317,35 @@ func (controller *Controller) readFrames( switch frame.Type { case execstream.FrameTypeStdin: - if stdinHandle == nil { - return fmt.Errorf("failed to handle %q frame: this exec session "+ - "has no stdin is enabled or already closed", frame.Type) + if err := session.writeStdin(frame.Data); err != nil { + return fmt.Errorf("failed to handle %q frame: %w", frame.Type, err) + } + case execstream.FrameTypeHistory: + if !session.policy.replayEnabled { + return fmt.Errorf("unexpected frame type received: %q", frame.Type) } - if len(frame.Data) == 0 { - if err := stdinHandle.Close(); err != nil { - return fmt.Errorf("failed to handle %q frame: failed to close "+ - "stdin: %w", frame.Type, err) - } - - stdinHandle = nil - - continue + session.sendHistory(subscriber, frame.Watermark) + case execstream.FrameTypeAck: + if !session.policy.replayEnabled { + return fmt.Errorf("unexpected frame type received: %q", frame.Type) } - if _, err := stdinHandle.Write(frame.Data); err != nil { - return fmt.Errorf("failed to handle %q frame: failed to write "+ - "to stdin: %w", frame.Type, err) + session.ack(frame.Watermark) + case execstream.FrameTypeDetach: + if !session.policy.replayEnabled { + return fmt.Errorf("unexpected frame type received: %q", frame.Type) } + + return errExecSessionDetached + case execstream.FrameTypeClose: + if !session.policy.replayEnabled { + return fmt.Errorf("unexpected frame type received: %q", frame.Type) + } + + session.close() + + return errExecSessionClosed default: return fmt.Errorf("unexpected frame type received: %q", frame.Type) } diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 906f9d2..1521e0d 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -55,6 +55,7 @@ type Controller struct { ipRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[string]] enableSwaggerDocs bool workerOfflineTimeout time.Duration + execSessionExitTTL time.Duration experimentalRPCV2 bool disableDBCompression bool pingInterval time.Duration @@ -64,6 +65,7 @@ type Controller struct { sshSigner ssh.Signer sshNoClientAuth bool sshServer *sshserver.SSHServer + execSessions *execSessionRegistry single singleflight.Group @@ -75,7 +77,9 @@ func New(opts ...Option) (*Controller, error) { connRendezvous: rendezvous.New[rendezvous.ResultWithErrorMessage[net.Conn]](), ipRendezvous: rendezvous.New[rendezvous.ResultWithErrorMessage[string]](), workerOfflineTimeout: 3 * time.Minute, + execSessionExitTTL: 10 * time.Minute, pingInterval: 30 * time.Second, + execSessions: newExecSessionRegistry(), single: singleflight.Group{}, } @@ -308,6 +312,8 @@ func (controller *Controller) Run(ctx context.Context) error { go func() { <-ctx.Done() + controller.execSessions.closeAll() + if err := controller.httpServer.Shutdown(ctx); err != nil { controller.logger.Errorf("failed to cleanly shutdown the HTTP server: %v", err) } diff --git a/internal/controller/exec_sessions.go b/internal/controller/exec_sessions.go new file mode 100644 index 0000000..7357cec --- /dev/null +++ b/internal/controller/exec_sessions.go @@ -0,0 +1,692 @@ +package controller + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/cirruslabs/orchard/internal/execstream" +) + +const execSessionReplayBufferBytes = 4 * 1024 * 1024 + +type execSessionPolicy struct { + closeOnDetach bool + retainAfterExit bool + replayEnabled bool +} + +var ( + legacyExecSessionPolicy = execSessionPolicy{ + closeOnDetach: true, + } + reconnectableExecSessionPolicy = execSessionPolicy{ + retainAfterExit: true, + replayEnabled: true, + } +) + +type sshExecRunner interface { + Stdin() io.WriteCloser + Run(ctx context.Context, command string, outgoingFrames chan<- *execstream.Frame) error + Close() error +} + +type execSessionKey struct { + vmName string + sessionID string +} + +type execSessionCreation struct { + done chan struct{} + session *execSession + err error +} + +type execSessionRegistry struct { + mu sync.Mutex + sessions map[execSessionKey]*execSession + creating map[execSessionKey]*execSessionCreation +} + +func newExecSessionRegistry() *execSessionRegistry { + return &execSessionRegistry{ + sessions: map[execSessionKey]*execSession{}, + creating: map[execSessionKey]*execSessionCreation{}, + } +} + +func (registry *execSessionRegistry) get(key execSessionKey) (*execSession, bool) { + registry.mu.Lock() + defer registry.mu.Unlock() + + session, ok := registry.sessions[key] + + return session, ok +} + +func (registry *execSessionRegistry) getOrCreate( + ctx context.Context, + key execSessionKey, + create func() (*execSession, error), +) (*execSession, bool, error) { + registry.mu.Lock() + + if session, ok := registry.sessions[key]; ok { + registry.mu.Unlock() + + return session, false, nil + } + + if creation, ok := registry.creating[key]; ok { + registry.mu.Unlock() + + select { + case <-ctx.Done(): + return nil, false, ctx.Err() + case <-creation.done: + if creation.err != nil { + return nil, false, creation.err + } + + return creation.session, false, nil + } + } + + creation := &execSessionCreation{done: make(chan struct{})} + registry.creating[key] = creation + registry.mu.Unlock() + + session, err := create() + + registry.mu.Lock() + delete(registry.creating, key) + if err == nil { + registry.sessions[key] = session + } + creation.session = session + creation.err = err + close(creation.done) + registry.mu.Unlock() + + return session, true, err +} + +func (registry *execSessionRegistry) remove(key execSessionKey, expected *execSession) { + registry.mu.Lock() + defer registry.mu.Unlock() + + if registry.sessions[key] == expected { + delete(registry.sessions, key) + } +} + +func (registry *execSessionRegistry) closeAll() { + registry.mu.Lock() + sessions := make([]*execSession, 0, len(registry.sessions)) + for _, session := range registry.sessions { + sessions = append(sessions, session) + } + registry.mu.Unlock() + + for _, session := range sessions { + session.close() + } +} + +type execReplayFrame struct { + frame *execstream.Frame + size int +} + +type execReplayBuffer struct { + frames []execReplayFrame + bufferBytes int + nextWatermark uint64 + ackedWatermark uint64 +} + +func (buffer *execReplayBuffer) append(frame *execstream.Frame) *execstream.Frame { + frame = cloneExecFrame(frame) + buffer.nextWatermark++ + frame.Watermark = buffer.nextWatermark + + frameSize := execFrameSize(frame) + buffer.frames = append(buffer.frames, execReplayFrame{ + frame: frame, + size: frameSize, + }) + buffer.bufferBytes += frameSize + buffer.trimAcknowledged() + buffer.trimToLimit() + + return frame +} + +func (buffer *execReplayBuffer) ack(watermark uint64) { + if watermark <= buffer.ackedWatermark { + return + } + + buffer.ackedWatermark = watermark + buffer.trimAcknowledged() +} + +func (buffer *execReplayBuffer) replayAfter( + watermark uint64, + frames []*execstream.Frame, +) []*execstream.Frame { + for _, record := range buffer.frames { + if record.frame.Watermark <= watermark { + continue + } + + frames = append(frames, record.frame) + } + + return frames +} + +func (buffer *execReplayBuffer) trimAcknowledged() { + for len(buffer.frames) > 0 && buffer.frames[0].frame.Watermark <= buffer.ackedWatermark { + buffer.bufferBytes -= buffer.frames[0].size + buffer.frames = buffer.frames[1:] + } +} + +func (buffer *execReplayBuffer) trimToLimit() { + for buffer.bufferBytes > execSessionReplayBufferBytes && len(buffer.frames) > 0 { + buffer.bufferBytes -= buffer.frames[0].size + buffer.frames = buffer.frames[1:] + } +} + +type execSessionSubscriber struct { + frames chan *execstream.Frame + closed chan struct{} + closeOnce sync.Once + sendMu sync.Mutex + sentWatermark uint64 +} + +func newExecSessionSubscriber() *execSessionSubscriber { + return &execSessionSubscriber{ + frames: make(chan *execstream.Frame, 128), + closed: make(chan struct{}), + } +} + +func (subscriber *execSessionSubscriber) enqueue(frame *execstream.Frame) bool { + subscriber.sendMu.Lock() + defer subscriber.sendMu.Unlock() + + if subscriber.alreadySentLocked(frame) { + return true + } + + select { + case <-subscriber.closed: + return false + default: + } + + select { + case subscriber.frames <- subscriber.markSentLocked(frame): + return true + case <-subscriber.closed: + return false + default: + return false + } +} + +func (subscriber *execSessionSubscriber) sendHistory(frames []*execstream.Frame) bool { + for _, frame := range frames { + if !subscriber.sendLocked(frame) { + return false + } + } + + return true +} + +func (subscriber *execSessionSubscriber) sendLocked(frame *execstream.Frame) bool { + if subscriber.alreadySentLocked(frame) { + return true + } + + select { + case <-subscriber.closed: + return false + default: + } + + select { + case subscriber.frames <- subscriber.markSentLocked(frame): + return true + case <-subscriber.closed: + return false + } +} + +func (subscriber *execSessionSubscriber) alreadySentLocked(frame *execstream.Frame) bool { + return isReplayOutputFrame(frame) && + frame.Watermark != 0 && + frame.Watermark <= subscriber.sentWatermark +} + +func (subscriber *execSessionSubscriber) markSentLocked(frame *execstream.Frame) *execstream.Frame { + frame = cloneExecFrame(frame) + if isReplayOutputFrame(frame) && frame.Watermark > subscriber.sentWatermark { + subscriber.sentWatermark = frame.Watermark + } + + return frame +} + +func (subscriber *execSessionSubscriber) close() { + subscriber.closeOnce.Do(func() { + close(subscriber.closed) + subscriber.sendMu.Lock() + close(subscriber.frames) + subscriber.sendMu.Unlock() + }) +} + +type execSession struct { + key execSessionKey + command string + exec sshExecRunner + transport net.Conn + registry *execSessionRegistry + exitTTL time.Duration + policy execSessionPolicy + + ctx context.Context + cancel context.CancelFunc + + mu sync.Mutex + stdin io.WriteCloser + stdinClosed bool + subscribers map[*execSessionSubscriber]struct{} + replay execReplayBuffer + started bool + finished bool + closed bool + expiryTimer *time.Timer + + startOnce sync.Once + done chan struct{} + doneOnce sync.Once +} + +func newExecSession( + key execSessionKey, + command string, + exec sshExecRunner, + transport net.Conn, + registry *execSessionRegistry, + exitTTL time.Duration, + policy execSessionPolicy, +) *execSession { + ctx, cancel := context.WithCancel(context.Background()) + + return newExecSessionWithContext( + ctx, + cancel, + key, + command, + exec, + transport, + registry, + exitTTL, + policy, + ) +} + +func newExecSessionWithContext( + ctx context.Context, + cancel context.CancelFunc, + key execSessionKey, + command string, + exec sshExecRunner, + transport net.Conn, + registry *execSessionRegistry, + exitTTL time.Duration, + policy execSessionPolicy, +) *execSession { + if ctx == nil || cancel == nil { + ctx, cancel = context.WithCancel(context.Background()) + } + + session := &execSession{ + key: key, + command: command, + exec: exec, + transport: transport, + registry: registry, + exitTTL: exitTTL, + policy: policy, + ctx: ctx, + cancel: cancel, + stdin: exec.Stdin(), + subscribers: map[*execSessionSubscriber]struct{}{}, + done: make(chan struct{}), + } + + return session +} + +func (session *execSession) commandMatches(command string) bool { + return command == "" || session.command == command +} + +func (session *execSession) start() { + session.startOnce.Do(func() { + session.mu.Lock() + if session.closed { + session.mu.Unlock() + + return + } + session.started = true + session.mu.Unlock() + + go session.run() + }) +} + +func (session *execSession) closeIfUnused() { + session.mu.Lock() + unused := !session.started && len(session.subscribers) == 0 + session.mu.Unlock() + + if unused { + session.close() + } +} + +func (session *execSession) attach() (*execSessionSubscriber, error) { + session.mu.Lock() + defer session.mu.Unlock() + + if session.closed { + return nil, errors.New("exec session is closed") + } + + subscriber := newExecSessionSubscriber() + session.subscribers[subscriber] = struct{}{} + + return subscriber, nil +} + +func (session *execSession) detach(subscriber *execSessionSubscriber) { + if session.policy.closeOnDetach { + session.close() + + return + } + + session.mu.Lock() + defer session.mu.Unlock() + + session.detachLocked(subscriber) +} + +func (session *execSession) detachLocked(subscriber *execSessionSubscriber) { + if _, ok := session.subscribers[subscriber]; !ok { + return + } + + delete(session.subscribers, subscriber) + subscriber.close() +} + +func (session *execSession) writeStdin(data []byte) error { + session.mu.Lock() + defer session.mu.Unlock() + + if session.stdin == nil || session.stdinClosed { + return errors.New("this exec session has no stdin enabled or it is already closed") + } + + if len(data) == 0 { + if err := session.stdin.Close(); err != nil { + return err + } + + session.stdinClosed = true + + return nil + } + + _, err := session.stdin.Write(data) + + return err +} + +func (session *execSession) ack(watermark uint64) { + if !session.policy.replayEnabled { + return + } + + session.mu.Lock() + defer session.mu.Unlock() + + session.replay.ack(watermark) +} + +func (session *execSession) sendHistory( + subscriber *execSessionSubscriber, + watermark uint64, +) { + if !session.policy.replayEnabled { + return + } + + session.mu.Lock() + + if _, ok := session.subscribers[subscriber]; !ok { + session.mu.Unlock() + + return + } + + subscriber.sendMu.Lock() + frames := session.replay.replayAfter(watermark, nil) + frames = append(frames, &execstream.Frame{ + Type: execstream.FrameTypeNoMoreHistory, + Watermark: session.replay.nextWatermark, + }) + session.mu.Unlock() + + ok := subscriber.sendHistory(frames) + subscriber.sendMu.Unlock() + + if !ok { + session.dropSubscriber(subscriber) + } +} + +func (session *execSession) close() { + session.mu.Lock() + if session.closed { + session.mu.Unlock() + + return + } + + session.closed = true + if session.expiryTimer != nil { + session.expiryTimer.Stop() + session.expiryTimer = nil + } + + subscribers := session.takeSubscribersLocked() + session.mu.Unlock() + + closeSubscribers(subscribers) + + session.cancel() + _ = session.exec.Close() + if session.transport != nil { + _ = session.transport.Close() + } + if session.registry != nil { + session.registry.remove(session.key, session) + } +} + +func (session *execSession) run() { + outgoingFrames := make(chan *execstream.Frame) + runErrCh := make(chan error, 1) + + go func() { + runErrCh <- session.exec.Run(session.ctx, session.command, outgoingFrames) + close(outgoingFrames) + }() + + for frame := range outgoingFrames { + session.recordFrame(frame) + } + + runErr := <-runErrCh + if runErr != nil && !errors.Is(runErr, context.Canceled) { + session.recordFrame(&execstream.Frame{ + Type: execstream.FrameTypeError, + Error: runErr.Error(), + }) + } + + session.markFinished() +} + +func (session *execSession) recordFrame(frame *execstream.Frame) { + session.mu.Lock() + + if session.closed { + session.mu.Unlock() + + return + } + + if session.policy.replayEnabled { + frame = session.replay.append(frame) + } else { + frame = cloneExecFrame(frame) + } + + subscribers := make([]*execSessionSubscriber, 0, len(session.subscribers)) + for subscriber := range session.subscribers { + subscribers = append(subscribers, subscriber) + } + session.mu.Unlock() + + for _, subscriber := range subscribers { + if !subscriber.enqueue(frame) { + session.dropSubscriber(subscriber) + } + } +} + +func (session *execSession) markFinished() { + session.mu.Lock() + if session.finished { + session.mu.Unlock() + + return + } + + session.finished = true + shouldClose := !session.policy.retainAfterExit + if !session.closed && session.policy.retainAfterExit { + session.expiryTimer = time.AfterFunc(session.exitTTL, session.expire) + } + + var subscribers []*execSessionSubscriber + if shouldClose { + subscribers = session.takeSubscribersLocked() + } + session.mu.Unlock() + + closeSubscribers(subscribers) + + session.doneOnce.Do(func() { + close(session.done) + }) + + if shouldClose { + session.close() + } +} + +func (session *execSession) expire() { + session.close() +} + +func (session *execSession) takeSubscribersLocked() []*execSessionSubscriber { + subscribers := make([]*execSessionSubscriber, 0, len(session.subscribers)) + for subscriber := range session.subscribers { + subscribers = append(subscribers, subscriber) + } + session.subscribers = map[*execSessionSubscriber]struct{}{} + + return subscribers +} + +func closeSubscribers(subscribers []*execSessionSubscriber) { + for _, subscriber := range subscribers { + subscriber.close() + } +} + +func (session *execSession) dropSubscriber(subscriber *execSessionSubscriber) { + session.mu.Lock() + defer session.mu.Unlock() + + session.detachLocked(subscriber) +} + +func cloneExecFrame(frame *execstream.Frame) *execstream.Frame { + if frame == nil { + return nil + } + + clone := *frame + if frame.Data != nil { + clone.Data = append([]byte(nil), frame.Data...) + } + if frame.Exit != nil { + exit := *frame.Exit + clone.Exit = &exit + } + + return &clone +} + +func execFrameSize(frame *execstream.Frame) int { + if frame == nil { + return 0 + } + + return len(frame.Data) + len(frame.Error) + 16 +} + +func isReplayOutputFrame(frame *execstream.Frame) bool { + if frame == nil { + return false + } + + switch frame.Type { + case execstream.FrameTypeStdout, + execstream.FrameTypeStderr, + execstream.FrameTypeExit, + execstream.FrameTypeError: + return true + default: + return false + } +} diff --git a/internal/controller/exec_sessions_test.go b/internal/controller/exec_sessions_test.go new file mode 100644 index 0000000..cd19eb6 --- /dev/null +++ b/internal/controller/exec_sessions_test.go @@ -0,0 +1,320 @@ +package controller + +import ( + "context" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/cirruslabs/orchard/internal/execstream" + "github.com/stretchr/testify/require" +) + +type fakeExec struct { + stdin io.WriteCloser + run func(context.Context, string, chan<- *execstream.Frame) error + closeCalls atomic.Int32 +} + +func (exec *fakeExec) Stdin() io.WriteCloser { + return exec.stdin +} + +func (exec *fakeExec) Run( + ctx context.Context, + command string, + outgoingFrames chan<- *execstream.Frame, +) error { + if exec.run != nil { + return exec.run(ctx, command, outgoingFrames) + } + + return nil +} + +func (exec *fakeExec) Close() error { + exec.closeCalls.Add(1) + + return nil +} + +func newManualExecSessionForTest( + key execSessionKey, + registry *execSessionRegistry, +) *execSession { + ctx, cancel := context.WithCancel(context.Background()) + + return &execSession{ + key: key, + command: "echo test", + exec: &fakeExec{}, + registry: registry, + exitTTL: time.Minute, + policy: reconnectableExecSessionPolicy, + ctx: ctx, + cancel: cancel, + subscribers: map[*execSessionSubscriber]struct{}{}, + done: make(chan struct{}), + } +} + +func TestExecSessionRegistryGetOrCreateReusesInflightCreation(t *testing.T) { + registry := newExecSessionRegistry() + key := execSessionKey{vmName: "vm", sessionID: "session"} + + createStarted := make(chan struct{}) + releaseCreate := make(chan struct{}) + var createCalls atomic.Int32 + + create := func() (*execSession, error) { + createCalls.Add(1) + close(createStarted) + <-releaseCreate + + return newManualExecSessionForTest(key, registry), nil + } + + firstDone := make(chan struct{}) + go func() { + defer close(firstDone) + _, _, err := registry.getOrCreate(context.Background(), key, create) + require.NoError(t, err) + }() + + <-createStarted + + secondDone := make(chan struct{}) + go func() { + defer close(secondDone) + _, created, err := registry.getOrCreate(context.Background(), key, create) + require.NoError(t, err) + require.False(t, created) + }() + + close(releaseCreate) + + <-firstDone + <-secondDone + require.EqualValues(t, 1, createCalls.Load()) +} + +func TestExecSessionStartRunsCommandOnlyOnce(t *testing.T) { + var runCalls atomic.Int32 + runStarted := make(chan struct{}) + + session := newExecSession( + execSessionKey{vmName: "vm", sessionID: "session"}, + "echo test", + &fakeExec{ + run: func(ctx context.Context, _ string, _ chan<- *execstream.Frame) error { + runCalls.Add(1) + close(runStarted) + <-ctx.Done() + + return ctx.Err() + }, + }, + nil, + nil, + time.Minute, + reconnectableExecSessionPolicy, + ) + defer session.close() + + session.start() + session.start() + + <-runStarted + require.EqualValues(t, 1, runCalls.Load()) +} + +func TestExecSessionHistoryReplayAndAck(t *testing.T) { + registry := newExecSessionRegistry() + session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) + + session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")}) + session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStderr, Data: []byte("err")}) + session.recordFrame(&execstream.Frame{ + Type: execstream.FrameTypeExit, + Exit: &execstream.Exit{Code: 7}, + }) + + subscriber, err := session.attach() + require.NoError(t, err) + + session.sendHistory(subscriber, 0) + + require.Equal(t, execstream.FrameTypeStdout, (<-subscriber.frames).Type) + require.Equal(t, execstream.FrameTypeStderr, (<-subscriber.frames).Type) + require.Equal(t, execstream.FrameTypeExit, (<-subscriber.frames).Type) + noMoreHistory := <-subscriber.frames + require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type) + require.EqualValues(t, 3, noMoreHistory.Watermark) + + session.ack(2) + require.Len(t, session.replay.frames, 1) + require.EqualValues(t, 3, session.replay.frames[0].frame.Watermark) +} + +func TestExecSessionHistoryReplayStreamsPastSubscriberBuffer(t *testing.T) { + registry := newExecSessionRegistry() + session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) + + const frameCount = 256 + for i := 0; i < frameCount; i++ { + session.recordFrame(&execstream.Frame{ + Type: execstream.FrameTypeStdout, + Data: []byte{byte(i)}, + }) + } + + subscriber, err := session.attach() + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + session.sendHistory(subscriber, 0) + }() + + for i := 1; i <= frameCount; i++ { + frame := <-subscriber.frames + require.Equal(t, execstream.FrameTypeStdout, frame.Type) + require.EqualValues(t, i, frame.Watermark) + } + + noMoreHistory := <-subscriber.frames + require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type) + require.EqualValues(t, frameCount, noMoreHistory.Watermark) + + require.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, time.Second, 10*time.Millisecond) +} + +func TestExecSessionDetachKeepsProcessAlive(t *testing.T) { + registry := newExecSessionRegistry() + session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) + exec := session.exec.(*fakeExec) + + subscriber, err := session.attach() + require.NoError(t, err) + + session.detach(subscriber) + + require.False(t, session.closed) + require.EqualValues(t, 0, exec.closeCalls.Load()) +} + +func TestLegacyExecSessionDetachStopsProcess(t *testing.T) { + registry := newExecSessionRegistry() + session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) + session.policy = legacyExecSessionPolicy + exec := session.exec.(*fakeExec) + + subscriber, err := session.attach() + require.NoError(t, err) + + session.detach(subscriber) + + require.True(t, session.closed) + require.EqualValues(t, 1, exec.closeCalls.Load()) +} + +func TestLegacyExecSessionDoesNotRetainReplayHistory(t *testing.T) { + registry := newExecSessionRegistry() + session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) + session.policy = legacyExecSessionPolicy + + session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")}) + + require.Empty(t, session.replay.frames) + require.Zero(t, session.replay.nextWatermark) +} + +func TestExecSessionCloseIfUnusedClosesIdleSession(t *testing.T) { + registry := newExecSessionRegistry() + key := execSessionKey{vmName: "vm", sessionID: "session"} + session := newManualExecSessionForTest(key, registry) + exec := session.exec.(*fakeExec) + registry.sessions[key] = session + + session.closeIfUnused() + + require.True(t, session.closed) + require.EqualValues(t, 1, exec.closeCalls.Load()) +} + +func TestExecSessionCloseIfUnusedKeepsAttachedSession(t *testing.T) { + registry := newExecSessionRegistry() + session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) + + _, err := session.attach() + require.NoError(t, err) + + session.closeIfUnused() + + require.False(t, session.closed) +} + +func TestExecSessionCloseStopsProcessAndRemovesRegistryEntry(t *testing.T) { + registry := newExecSessionRegistry() + key := execSessionKey{vmName: "vm", sessionID: "session"} + session := newManualExecSessionForTest(key, registry) + exec := session.exec.(*fakeExec) + registry.sessions[key] = session + + session.close() + + require.True(t, session.closed) + require.EqualValues(t, 1, exec.closeCalls.Load()) + _, ok := registry.get(key) + require.False(t, ok) +} + +func TestExecSessionFinishedEntryExpiresAfterTTL(t *testing.T) { + registry := newExecSessionRegistry() + key := execSessionKey{vmName: "vm", sessionID: "session"} + session := newManualExecSessionForTest(key, registry) + session.exitTTL = 10 * time.Millisecond + registry.sessions[key] = session + + session.markFinished() + + require.Eventually(t, func() bool { + _, ok := registry.get(key) + + return !ok + }, time.Second, 10*time.Millisecond) +} + +func TestExecSessionFinishKeepsReconnectableSubscriberOpen(t *testing.T) { + registry := newExecSessionRegistry() + session := newManualExecSessionForTest(execSessionKey{vmName: "vm", sessionID: "session"}, registry) + + subscriber, err := session.attach() + require.NoError(t, err) + + session.recordFrame(&execstream.Frame{Type: execstream.FrameTypeStdout, Data: []byte("out")}) + session.recordFrame(&execstream.Frame{ + Type: execstream.FrameTypeExit, + Exit: &execstream.Exit{Code: 0}, + }) + session.markFinished() + + require.Equal(t, execstream.FrameTypeStdout, (<-subscriber.frames).Type) + require.Equal(t, execstream.FrameTypeExit, (<-subscriber.frames).Type) + + session.sendHistory(subscriber, 0) + + noMoreHistory, ok := <-subscriber.frames + require.True(t, ok) + require.Equal(t, execstream.FrameTypeNoMoreHistory, noMoreHistory.Type) + require.EqualValues(t, 2, noMoreHistory.Watermark) +} diff --git a/internal/controller/option.go b/internal/controller/option.go index 8370d50..325c912 100644 --- a/internal/controller/option.go +++ b/internal/controller/option.go @@ -60,6 +60,12 @@ func WithWorkerOfflineTimeout(workerOfflineTimeout time.Duration) Option { } } +func WithExecSessionExitTTL(execSessionExitTTL time.Duration) Option { + return func(controller *Controller) { + controller.execSessionExitTTL = execSessionExitTTL + } +} + func WithExperimentalRPCV2() Option { return func(controller *Controller) { controller.experimentalRPCV2 = true diff --git a/internal/controller/sshexec/sshexec.go b/internal/controller/sshexec/sshexec.go index 41c8f70..bb412c3 100644 --- a/internal/controller/sshexec/sshexec.go +++ b/internal/controller/sshexec/sshexec.go @@ -14,11 +14,12 @@ import ( ) type Exec struct { - sshClient *ssh.Client - sshSession *ssh.Session - stdout io.Reader - stderr io.Reader - stdin io.WriteCloser + sshClient *ssh.Client + sshSession *ssh.Session + stdout io.Reader + stderr io.Reader + stdin io.WriteCloser + stdinReader *io.PipeReader } func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, error) { @@ -52,14 +53,10 @@ func New(netConn net.Conn, user string, password string, stdin bool) (*Exec, err } if stdin { - exec.stdin, err = sshSession.StdinPipe() - if err != nil { - _ = sshSession.Close() - _ = sshClient.Close() - - return nil, fmt.Errorf("failed to create standard input pipe "+ - "for the SSH session: %w", err) - } + stdinReader, stdinWriter := io.Pipe() + sshSession.Stdin = stdinReader + exec.stdinReader = stdinReader + exec.stdin = stdinWriter } exec.stdout, err = sshSession.StdoutPipe() @@ -92,6 +89,12 @@ func (exec *Exec) Run( command string, outgoingFrames chan<- *execstream.Frame, ) error { + if exec.stdinReader != nil { + defer func() { + _ = exec.stdinReader.Close() + }() + } + if err := exec.sshSession.Start(command); err != nil { return fmt.Errorf("failed to start command %q: %w", command, err) } @@ -188,6 +191,13 @@ func ioStreamReader( } func (exec *Exec) Close() error { + if exec.stdin != nil { + _ = exec.stdin.Close() + } + if exec.stdinReader != nil { + _ = exec.stdinReader.Close() + } + if err := exec.sshSession.Close(); err != nil { _ = exec.sshClient.Close() diff --git a/internal/execstream/frame.go b/internal/execstream/frame.go index c40f252..ab41f0b 100644 --- a/internal/execstream/frame.go +++ b/internal/execstream/frame.go @@ -10,18 +10,24 @@ import ( type FrameType string const ( - FrameTypeStdin FrameType = "stdin" - FrameTypeStdout FrameType = "stdout" - FrameTypeStderr FrameType = "stderr" - FrameTypeExit FrameType = "exit" - FrameTypeError FrameType = "error" + FrameTypeStdin FrameType = "stdin" + FrameTypeStdout FrameType = "stdout" + FrameTypeStderr FrameType = "stderr" + FrameTypeExit FrameType = "exit" + FrameTypeError FrameType = "error" + FrameTypeHistory FrameType = "history" + FrameTypeNoMoreHistory FrameType = "no_more_history" + FrameTypeAck FrameType = "ack" + FrameTypeDetach FrameType = "detach" + FrameTypeClose FrameType = "close" ) type Frame struct { - Type FrameType `json:"type"` - Data []byte `json:"data,omitempty"` - Exit *Exit `json:"exit,omitempty"` - Error string `json:"error,omitempty"` + Type FrameType `json:"type"` + Data []byte `json:"data,omitempty"` + Exit *Exit `json:"exit,omitempty"` + Error string `json:"error,omitempty"` + Watermark uint64 `json:"watermark,omitempty"` } type Exit struct { diff --git a/internal/execstream/frame_test.go b/internal/execstream/frame_test.go new file mode 100644 index 0000000..673f0c3 --- /dev/null +++ b/internal/execstream/frame_test.go @@ -0,0 +1,23 @@ +package execstream + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFrameRoundTripsWatermark(t *testing.T) { + frame := Frame{ + Type: FrameTypeHistory, + Watermark: 42, + } + + payload, err := json.Marshal(frame) + require.NoError(t, err) + + var decoded Frame + err = json.Unmarshal(payload, &decoded) + require.NoError(t, err) + require.Equal(t, frame, decoded) +} diff --git a/internal/tests/exec_test.go b/internal/tests/exec_test.go index 227f66d..6d2005c 100644 --- a/internal/tests/exec_test.go +++ b/internal/tests/exec_test.go @@ -2,6 +2,7 @@ package tests_test import ( "bytes" + "context" "encoding/json" "testing" "time" @@ -132,6 +133,176 @@ func TestVMExecScript(t *testing.T) { require.Equal(t, websocket.StatusNormalClosure, closeError.Code) } +func TestVMExecSessionReconnectHistory(t *testing.T) { + devClient, vmName := prepareForExec(t) + sessionID := uuid.NewString() + + wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + Command: "sh -c 'echo first; sleep 1; echo second'", + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + + firstFrame := readFrame(t, wsConn) + require.Equal(t, execstream.FrameTypeStdout, firstFrame.Type) + require.Equal(t, "first\n", string(firstFrame.Data)) + require.EqualValues(t, 1, firstFrame.Watermark) + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach}) + require.NoError(t, err) + _ = wsConn.CloseNow() + + // Let the detached process finish so this test verifies partial replay + // without relying on live-output timing. + time.Sleep(2 * time.Second) + + wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + defer wsConn.CloseNow() + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{ + Type: execstream.FrameTypeHistory, + Watermark: firstFrame.Watermark, + }) + require.NoError(t, err) + + frames := readFramesUntilExit(t, wsConn) + require.Len(t, framesByType(frames, execstream.FrameTypeStdout), 1) + require.Equal(t, "second\n", string(framesByType(frames, execstream.FrameTypeStdout)[0].Data)) + require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code) +} + +func TestVMExecSessionReconnectAfterExit(t *testing.T) { + devClient, vmName := prepareForExec(t) + sessionID := uuid.NewString() + + wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + Command: "sh -c 'echo replay-me'", + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach}) + require.NoError(t, err) + _ = wsConn.CloseNow() + + time.Sleep(time.Second) + + wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + defer wsConn.CloseNow() + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory}) + require.NoError(t, err) + + frames := readFramesUntilExit(t, wsConn) + require.Equal(t, "replay-me\n", string(framesByType(frames, execstream.FrameTypeStdout)[0].Data)) + require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code) +} + +func TestVMExecSessionReplayPreservesStreams(t *testing.T) { + devClient, vmName := prepareForExec(t) + sessionID := uuid.NewString() + + wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + Command: "sh -c 'echo out1; sleep 1; echo err1 >&2; sleep 1; echo out2; sleep 1; echo err2 >&2'", + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach}) + require.NoError(t, err) + _ = wsConn.CloseNow() + + time.Sleep(4 * time.Second) + + wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + defer wsConn.CloseNow() + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory}) + require.NoError(t, err) + + frames := readFramesUntilExit(t, wsConn) + require.Equal(t, []execstream.FrameType{ + execstream.FrameTypeStdout, + execstream.FrameTypeStderr, + execstream.FrameTypeStdout, + execstream.FrameTypeStderr, + execstream.FrameTypeExit, + }, frameTypes(frames)) + require.Equal(t, "out1\n", string(frames[0].Data)) + require.Equal(t, "err1\n", string(frames[1].Data)) + require.Equal(t, "out2\n", string(frames[2].Data)) + require.Equal(t, "err2\n", string(frames[3].Data)) +} + +func TestVMExecSessionStdinSurvivesReconnect(t *testing.T) { + devClient, vmName := prepareForExec(t) + sessionID := uuid.NewString() + + wsConn, err := devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + Command: "/bin/cat", + Stdin: true, + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{ + Type: execstream.FrameTypeStdin, + Data: []byte("one\n"), + }) + require.NoError(t, err) + + frame := readFrame(t, wsConn) + require.Equal(t, execstream.FrameTypeStdout, frame.Type) + require.Equal(t, "one\n", string(frame.Data)) + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeDetach}) + require.NoError(t, err) + _ = wsConn.CloseNow() + + wsConn, err = devClient.VMs().ExecSession(t.Context(), vmName, client.ExecSessionOptions{ + WaitSeconds: 30, + Session: sessionID, + }) + require.NoError(t, err) + defer wsConn.CloseNow() + + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{ + Type: execstream.FrameTypeStdin, + Data: []byte("two\n"), + }) + require.NoError(t, err) + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{ + Type: execstream.FrameTypeStdin, + Data: []byte{}, + }) + require.NoError(t, err) + err = execstream.WriteFrame(t.Context(), wsConn, &execstream.Frame{Type: execstream.FrameTypeHistory}) + require.NoError(t, err) + + frames := readFramesUntilExit(t, wsConn) + stdoutFrames := framesByType(frames, execstream.FrameTypeStdout) + require.Len(t, stdoutFrames, 2) + require.Equal(t, "one\n", string(stdoutFrames[0].Data)) + require.Equal(t, "two\n", string(stdoutFrames[1].Data)) + require.EqualValues(t, 0, framesByType(frames, execstream.FrameTypeExit)[0].Exit.Code) +} + func prepareForExec(t *testing.T) (*client.Client, string) { devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) @@ -153,14 +324,62 @@ func prepareForExec(t *testing.T) (*client.Client, string) { } func readFrame(t *testing.T, wsConn *websocket.Conn) *execstream.Frame { + t.Helper() + var frame execstream.Frame - messageType, payloadBytes, err := wsConn.Read(t.Context()) + readCtx, readCtxCancel := context.WithTimeout(t.Context(), 30*time.Second) + defer readCtxCancel() + + messageType, payloadBytes, err := wsConn.Read(readCtx) require.NoError(t, err) require.Equal(t, websocket.MessageText, messageType) err = json.Unmarshal(payloadBytes, &frame) require.NoError(t, err) + if frame.Type == execstream.FrameTypeError { + require.FailNowf(t, "exec stream error", "%s", frame.Error) + } return &frame } + +func readFramesUntilExit(t *testing.T, wsConn *websocket.Conn) []*execstream.Frame { + t.Helper() + + var frames []*execstream.Frame + + for { + frame := readFrame(t, wsConn) + if frame.Type == execstream.FrameTypeNoMoreHistory { + continue + } + + frames = append(frames, frame) + if frame.Type == execstream.FrameTypeExit { + return frames + } + } +} + +func framesByType(frames []*execstream.Frame, frameType execstream.FrameType) []*execstream.Frame { + var result []*execstream.Frame + + for _, frame := range frames { + if frame.Type == frameType { + result = append(result, frame) + } + } + + return result +} + +func frameTypes(frames []*execstream.Frame) []execstream.FrameType { + var result []execstream.FrameType + + for _, frame := range frames { + result = append(result, frame.Type) + } + + return result +} diff --git a/pkg/client/vms.go b/pkg/client/vms.go index 3bcbb6e..7b73544 100644 --- a/pkg/client/vms.go +++ b/pkg/client/vms.go @@ -47,6 +47,13 @@ type EventsPageOptions struct { Cursor string } +type ExecSessionOptions struct { + Command string + Stdin bool + WaitSeconds uint16 + Session string +} + func (service *VMsService) Create(ctx context.Context, vm *v1.VM) error { err := service.client.request(ctx, http.MethodPost, "vms", vm, nil, nil) @@ -164,12 +171,33 @@ func (service *VMsService) Exec( stdin bool, waitSeconds uint16, ) (*websocket.Conn, error) { + return service.ExecSession(ctx, name, ExecSessionOptions{ + Command: command, + Stdin: stdin, + WaitSeconds: waitSeconds, + }) +} + +func (service *VMsService) ExecSession( + ctx context.Context, + name string, + options ExecSessionOptions, +) (*websocket.Conn, error) { + params := map[string]string{ + "wait": strconv.FormatUint(uint64(options.WaitSeconds), 10), + } + if options.Command != "" { + params["command"] = options.Command + } + if options.Stdin { + params["stdin"] = strconv.FormatBool(true) + } + if options.Session != "" { + params["session"] = options.Session + } + return service.client.wsRequestRaw(ctx, fmt.Sprintf("vms/%s/exec", url.PathEscape(name)), - map[string]string{ - "command": command, - "stdin": strconv.FormatBool(stdin), - "wait": strconv.FormatUint(uint64(waitSeconds), 10), - }) + params) } func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint16) (string, error) { diff --git a/pkg/client/vms_test.go b/pkg/client/vms_test.go new file mode 100644 index 0000000..79f3394 --- /dev/null +++ b/pkg/client/vms_test.go @@ -0,0 +1,40 @@ +package client + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +func TestExecSessionBuildsReconnectableQuery(t *testing.T) { + var query map[string][]string + + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + query = request.URL.Query() + + conn, err := websocket.Accept(writer, request, nil) + require.NoError(t, err) + defer conn.CloseNow() + })) + defer server.Close() + + devClient, err := New(WithAddress(server.URL)) + require.NoError(t, err) + + conn, err := devClient.VMs().ExecSession(t.Context(), "vm", ExecSessionOptions{ + Command: "echo hello", + Stdin: true, + WaitSeconds: 7, + Session: "resume-me", + }) + require.NoError(t, err) + defer conn.CloseNow() + + require.Equal(t, []string{"echo hello"}, query["command"]) + require.Equal(t, []string{"true"}, query["stdin"]) + require.Equal(t, []string{"7"}, query["wait"]) + require.Equal(t, []string{"resume-me"}, query["session"]) +}