693 lines
14 KiB
Go
693 lines
14 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cirruslabs/orchard/internal/execstream"
|
|
)
|
|
|
|
// execSessionReplayBufferBytes caps the per-session replay backlog at 4 MiB;
// once exceeded, the oldest buffered frames are dropped (see trimToLimit).
const execSessionReplayBufferBytes = 4 * 1024 * 1024
|
|
|
|
// execSessionPolicy controls an exec session's lifecycle behavior.
type execSessionPolicy struct {
	// closeOnDetach tears the whole session down as soon as a client
	// detaches (legacy, non-reconnectable protocol).
	closeOnDetach bool
	// retainAfterExit keeps the session alive for exitTTL after the
	// command finishes so clients can reconnect and fetch the outcome.
	retainAfterExit bool
	// replayEnabled buffers output frames so reconnecting clients can
	// replay what they missed.
	replayEnabled bool
}
|
|
|
|
var (
	// legacyExecSessionPolicy models the original protocol: a single
	// client per session, and the session dies when it detaches.
	legacyExecSessionPolicy = execSessionPolicy{
		closeOnDetach: true,
	}
	// reconnectableExecSessionPolicy retains the session after the
	// command exits and buffers output for replay on reconnect.
	reconnectableExecSessionPolicy = execSessionPolicy{
		retainAfterExit: true,
		replayEnabled:   true,
	}
)
|
|
|
|
// sshExecRunner abstracts the transport side of an exec session: a
// stdin pipe, a Run loop that emits output frames, and teardown.
type sshExecRunner interface {
	// Stdin returns the command's standard-input pipe.
	Stdin() io.WriteCloser
	// Run executes command and streams its output into outgoingFrames
	// until the command finishes or ctx is cancelled.
	Run(ctx context.Context, command string, outgoingFrames chan<- *execstream.Frame) error
	// Close releases the underlying resources.
	Close() error
}
|
|
|
|
// execSessionKey uniquely identifies an exec session within the
// controller: one session per (VM name, session ID) pair.
type execSessionKey struct {
	vmName    string
	sessionID string
}
|
|
|
|
// execSessionCreation is the placeholder published in
// execSessionRegistry.creating while create() is running. session and
// err are written before done is closed, so waiters may read them
// lock-free after <-done.
type execSessionCreation struct {
	done    chan struct{}
	session *execSession
	err     error
}
|
|
|
|
// execSessionRegistry tracks live exec sessions and in-flight session
// creations, keyed by (VM name, session ID).
type execSessionRegistry struct {
	mu       sync.Mutex
	sessions map[execSessionKey]*execSession
	// creating holds placeholders for sessions whose create() is still
	// running, so concurrent getOrCreate callers wait instead of racing.
	creating map[execSessionKey]*execSessionCreation
}
|
|
|
|
func newExecSessionRegistry() *execSessionRegistry {
|
|
return &execSessionRegistry{
|
|
sessions: map[execSessionKey]*execSession{},
|
|
creating: map[execSessionKey]*execSessionCreation{},
|
|
}
|
|
}
|
|
|
|
func (registry *execSessionRegistry) get(key execSessionKey) (*execSession, bool) {
|
|
registry.mu.Lock()
|
|
defer registry.mu.Unlock()
|
|
|
|
session, ok := registry.sessions[key]
|
|
|
|
return session, ok
|
|
}
|
|
|
|
// getOrCreate returns the session registered under key, creating it
// via create() when absent. The boolean result reports whether this
// call performed the creation. Concurrent callers for the same key do
// not race create(): the first caller publishes an execSessionCreation
// placeholder and the rest wait on its done channel (or bail out when
// ctx is cancelled).
func (registry *execSessionRegistry) getOrCreate(
	ctx context.Context,
	key execSessionKey,
	create func() (*execSession, error),
) (*execSession, bool, error) {
	registry.mu.Lock()

	// Fast path: the session already exists.
	if session, ok := registry.sessions[key]; ok {
		registry.mu.Unlock()

		return session, false, nil
	}

	// Another goroutine is creating this session — wait for its outcome.
	if creation, ok := registry.creating[key]; ok {
		registry.mu.Unlock()

		select {
		case <-ctx.Done():
			return nil, false, ctx.Err()
		case <-creation.done:
			// Lock-free reads are safe here: creation's fields are
			// written before close(creation.done).
			if creation.err != nil {
				return nil, false, creation.err
			}

			return creation.session, false, nil
		}
	}

	// We are the creator: publish a placeholder, then run create()
	// outside the lock so other keys are not blocked meanwhile.
	creation := &execSessionCreation{done: make(chan struct{})}
	registry.creating[key] = creation
	registry.mu.Unlock()

	session, err := create()

	registry.mu.Lock()
	delete(registry.creating, key)
	if err == nil {
		registry.sessions[key] = session
	}
	creation.session = session
	creation.err = err
	close(creation.done)
	registry.mu.Unlock()

	return session, true, err
}
|
|
|
|
func (registry *execSessionRegistry) remove(key execSessionKey, expected *execSession) {
|
|
registry.mu.Lock()
|
|
defer registry.mu.Unlock()
|
|
|
|
if registry.sessions[key] == expected {
|
|
delete(registry.sessions, key)
|
|
}
|
|
}
|
|
|
|
func (registry *execSessionRegistry) closeAll() {
|
|
registry.mu.Lock()
|
|
sessions := make([]*execSession, 0, len(registry.sessions))
|
|
for _, session := range registry.sessions {
|
|
sessions = append(sessions, session)
|
|
}
|
|
registry.mu.Unlock()
|
|
|
|
for _, session := range sessions {
|
|
session.close()
|
|
}
|
|
}
|
|
|
|
// execReplayFrame is one buffered frame plus its precomputed size,
// used for byte-based accounting of the replay buffer.
type execReplayFrame struct {
	frame *execstream.Frame
	size  int
}
|
|
|
|
// execReplayBuffer retains recently emitted output frames so a
// reconnecting client can catch up. Frames are dropped once
// acknowledged, or from the front once the byte limit is exceeded.
// It is not self-synchronized: callers guard it with execSession.mu.
type execReplayBuffer struct {
	frames      []execReplayFrame
	bufferBytes int
	// nextWatermark is the watermark stamped on the most recently
	// appended frame; it only ever grows.
	nextWatermark uint64
	// ackedWatermark is the highest watermark the client confirmed.
	ackedWatermark uint64
}
|
|
|
|
// append deep-copies frame, stamps it with the next watermark, stores
// the copy in the replay buffer, and returns it. Acknowledged frames
// are trimmed first, then the buffer is shrunk to the byte limit — so
// a slow-to-ACK client may lose its oldest unacknowledged frames.
func (buffer *execReplayBuffer) append(frame *execstream.Frame) *execstream.Frame {
	// Clone so the stored frame is immune to caller-side mutation.
	frame = cloneExecFrame(frame)
	buffer.nextWatermark++
	frame.Watermark = buffer.nextWatermark

	frameSize := execFrameSize(frame)
	buffer.frames = append(buffer.frames, execReplayFrame{
		frame: frame,
		size:  frameSize,
	})
	buffer.bufferBytes += frameSize
	buffer.trimAcknowledged()
	buffer.trimToLimit()

	return frame
}
|
|
|
|
func (buffer *execReplayBuffer) ack(watermark uint64) {
|
|
if watermark <= buffer.ackedWatermark {
|
|
return
|
|
}
|
|
|
|
buffer.ackedWatermark = watermark
|
|
buffer.trimAcknowledged()
|
|
}
|
|
|
|
func (buffer *execReplayBuffer) replayAfter(
|
|
watermark uint64,
|
|
frames []*execstream.Frame,
|
|
) []*execstream.Frame {
|
|
for _, record := range buffer.frames {
|
|
if record.frame.Watermark <= watermark {
|
|
continue
|
|
}
|
|
|
|
frames = append(frames, record.frame)
|
|
}
|
|
|
|
return frames
|
|
}
|
|
|
|
func (buffer *execReplayBuffer) trimAcknowledged() {
|
|
for len(buffer.frames) > 0 && buffer.frames[0].frame.Watermark <= buffer.ackedWatermark {
|
|
buffer.bufferBytes -= buffer.frames[0].size
|
|
buffer.frames = buffer.frames[1:]
|
|
}
|
|
}
|
|
|
|
func (buffer *execReplayBuffer) trimToLimit() {
|
|
for buffer.bufferBytes > execSessionReplayBufferBytes && len(buffer.frames) > 0 {
|
|
buffer.bufferBytes -= buffer.frames[0].size
|
|
buffer.frames = buffer.frames[1:]
|
|
}
|
|
}
|
|
|
|
// execSessionSubscriber is one attached client's view of a session: a
// buffered frame queue plus bookkeeping to avoid re-delivering frames
// the client has already received.
type execSessionSubscriber struct {
	frames chan *execstream.Frame
	// closed is closed (strictly before frames) to unblock any sender.
	closed    chan struct{}
	closeOnce sync.Once
	// sendMu serializes all sends into frames and guards sentWatermark.
	sendMu sync.Mutex
	// sentWatermark is the highest replayable-frame watermark delivered
	// to this subscriber so far.
	sentWatermark uint64
}
|
|
|
|
func newExecSessionSubscriber() *execSessionSubscriber {
|
|
return &execSessionSubscriber{
|
|
frames: make(chan *execstream.Frame, 128),
|
|
closed: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// enqueue offers frame to the subscriber without blocking. It returns
// false when the subscriber is closed or its queue is full — callers
// treat false as "drop this subscriber". Replayable frames that were
// already delivered (watermark <= sentWatermark) are skipped yet still
// reported as success.
func (subscriber *execSessionSubscriber) enqueue(frame *execstream.Frame) bool {
	subscriber.sendMu.Lock()
	defer subscriber.sendMu.Unlock()

	if subscriber.alreadySentLocked(frame) {
		return true
	}

	// Explicit closed check first, so a closed subscriber is reported
	// as closed rather than appearing to have a full queue.
	select {
	case <-subscriber.closed:
		return false
	default:
	}

	// Non-blocking send: the default case drops the frame when the
	// queue is full (slow consumer).
	select {
	case subscriber.frames <- subscriber.markSentLocked(frame):
		return true
	case <-subscriber.closed:
		return false
	default:
		return false
	}
}
|
|
|
|
func (subscriber *execSessionSubscriber) sendHistory(frames []*execstream.Frame) bool {
|
|
for _, frame := range frames {
|
|
if !subscriber.sendLocked(frame) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// sendLocked delivers frame to the subscriber, blocking until there is
// queue space or the subscriber is closed. Unlike enqueue, it never
// drops frames — it is used for history replay, where completeness and
// ordering matter. The caller must hold sendMu.
func (subscriber *execSessionSubscriber) sendLocked(frame *execstream.Frame) bool {
	if subscriber.alreadySentLocked(frame) {
		return true
	}

	select {
	case <-subscriber.closed:
		return false
	default:
	}

	// Blocking send. close() closes the "closed" channel before taking
	// sendMu, so a sender blocked here always wakes up and can release
	// sendMu — no deadlock against close().
	select {
	case subscriber.frames <- subscriber.markSentLocked(frame):
		return true
	case <-subscriber.closed:
		return false
	}
}
|
|
|
|
func (subscriber *execSessionSubscriber) alreadySentLocked(frame *execstream.Frame) bool {
|
|
return isReplayOutputFrame(frame) &&
|
|
frame.Watermark != 0 &&
|
|
frame.Watermark <= subscriber.sentWatermark
|
|
}
|
|
|
|
func (subscriber *execSessionSubscriber) markSentLocked(frame *execstream.Frame) *execstream.Frame {
|
|
frame = cloneExecFrame(frame)
|
|
if isReplayOutputFrame(frame) && frame.Watermark > subscriber.sentWatermark {
|
|
subscriber.sentWatermark = frame.Watermark
|
|
}
|
|
|
|
return frame
|
|
}
|
|
|
|
// close makes the subscriber permanently unusable, exactly once. The
// "closed" channel is closed before sendMu is acquired so that any
// sender blocked in sendLocked wakes up and releases sendMu; only then
// is the frames channel closed, which is safe because every send
// happens under sendMu.
func (subscriber *execSessionSubscriber) close() {
	subscriber.closeOnce.Do(func() {
		close(subscriber.closed)
		subscriber.sendMu.Lock()
		close(subscriber.frames)
		subscriber.sendMu.Unlock()
	})
}
|
|
|
|
// execSession is one command execution on a VM: it runs the command
// via an sshExecRunner, fans output frames out to attached
// subscribers, and (policy permitting) retains a replay buffer so
// clients can reconnect.
type execSession struct {
	// Immutable after construction.
	key       execSessionKey
	command   string
	exec      sshExecRunner
	transport net.Conn
	registry  *execSessionRegistry
	// exitTTL is how long a retainAfterExit session survives after the
	// command finishes.
	exitTTL time.Duration
	policy  execSessionPolicy

	// ctx/cancel bound the lifetime of the exec.Run call.
	ctx    context.Context
	cancel context.CancelFunc

	// mu guards everything below.
	mu          sync.Mutex
	stdin       io.WriteCloser
	stdinClosed bool
	subscribers map[*execSessionSubscriber]struct{}
	replay      execReplayBuffer
	started     bool
	finished    bool
	closed      bool
	// expiryTimer closes the session exitTTL after the command exits
	// (retainAfterExit sessions only).
	expiryTimer *time.Timer

	startOnce sync.Once
	// done is closed when the command has finished running.
	done     chan struct{}
	doneOnce sync.Once
}
|
|
|
|
func newExecSession(
|
|
key execSessionKey,
|
|
command string,
|
|
exec sshExecRunner,
|
|
transport net.Conn,
|
|
registry *execSessionRegistry,
|
|
exitTTL time.Duration,
|
|
policy execSessionPolicy,
|
|
) *execSession {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
return newExecSessionWithContext(
|
|
ctx,
|
|
cancel,
|
|
key,
|
|
command,
|
|
exec,
|
|
transport,
|
|
registry,
|
|
exitTTL,
|
|
policy,
|
|
)
|
|
}
|
|
|
|
func newExecSessionWithContext(
|
|
ctx context.Context,
|
|
cancel context.CancelFunc,
|
|
key execSessionKey,
|
|
command string,
|
|
exec sshExecRunner,
|
|
transport net.Conn,
|
|
registry *execSessionRegistry,
|
|
exitTTL time.Duration,
|
|
policy execSessionPolicy,
|
|
) *execSession {
|
|
if ctx == nil || cancel == nil {
|
|
ctx, cancel = context.WithCancel(context.Background())
|
|
}
|
|
|
|
session := &execSession{
|
|
key: key,
|
|
command: command,
|
|
exec: exec,
|
|
transport: transport,
|
|
registry: registry,
|
|
exitTTL: exitTTL,
|
|
policy: policy,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
stdin: exec.Stdin(),
|
|
subscribers: map[*execSessionSubscriber]struct{}{},
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
return session
|
|
}
|
|
|
|
func (session *execSession) commandMatches(command string) bool {
|
|
return command == "" || session.command == command
|
|
}
|
|
|
|
func (session *execSession) start() {
|
|
session.startOnce.Do(func() {
|
|
session.mu.Lock()
|
|
if session.closed {
|
|
session.mu.Unlock()
|
|
|
|
return
|
|
}
|
|
session.started = true
|
|
session.mu.Unlock()
|
|
|
|
go session.run()
|
|
})
|
|
}
|
|
|
|
func (session *execSession) closeIfUnused() {
|
|
session.mu.Lock()
|
|
unused := !session.started && len(session.subscribers) == 0
|
|
session.mu.Unlock()
|
|
|
|
if unused {
|
|
session.close()
|
|
}
|
|
}
|
|
|
|
func (session *execSession) attach() (*execSessionSubscriber, error) {
|
|
session.mu.Lock()
|
|
defer session.mu.Unlock()
|
|
|
|
if session.closed {
|
|
return nil, errors.New("exec session is closed")
|
|
}
|
|
|
|
subscriber := newExecSessionSubscriber()
|
|
session.subscribers[subscriber] = struct{}{}
|
|
|
|
return subscriber, nil
|
|
}
|
|
|
|
func (session *execSession) detach(subscriber *execSessionSubscriber) {
|
|
if session.policy.closeOnDetach {
|
|
session.close()
|
|
|
|
return
|
|
}
|
|
|
|
session.mu.Lock()
|
|
defer session.mu.Unlock()
|
|
|
|
session.detachLocked(subscriber)
|
|
}
|
|
|
|
func (session *execSession) detachLocked(subscriber *execSessionSubscriber) {
|
|
if _, ok := session.subscribers[subscriber]; !ok {
|
|
return
|
|
}
|
|
|
|
delete(session.subscribers, subscriber)
|
|
subscriber.close()
|
|
}
|
|
|
|
func (session *execSession) writeStdin(data []byte) error {
|
|
session.mu.Lock()
|
|
defer session.mu.Unlock()
|
|
|
|
if session.stdin == nil || session.stdinClosed {
|
|
return errors.New("this exec session has no stdin enabled or it is already closed")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
if err := session.stdin.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
session.stdinClosed = true
|
|
|
|
return nil
|
|
}
|
|
|
|
_, err := session.stdin.Write(data)
|
|
|
|
return err
|
|
}
|
|
|
|
func (session *execSession) ack(watermark uint64) {
|
|
if !session.policy.replayEnabled {
|
|
return
|
|
}
|
|
|
|
session.mu.Lock()
|
|
defer session.mu.Unlock()
|
|
|
|
session.replay.ack(watermark)
|
|
}
|
|
|
|
// sendHistory replays all buffered frames newer than watermark to the
// subscriber, followed by a no-more-history marker carrying the
// current high watermark. Holding subscriber.sendMu across the
// session.mu release keeps replayed history strictly ordered before
// any live frames enqueued concurrently by recordFrame.
func (session *execSession) sendHistory(
	subscriber *execSessionSubscriber,
	watermark uint64,
) {
	if !session.policy.replayEnabled {
		return
	}

	session.mu.Lock()

	// Skip subscribers that were detached before we took the lock.
	if _, ok := session.subscribers[subscriber]; !ok {
		session.mu.Unlock()

		return
	}

	// Lock order is session.mu → sendMu everywhere in this file, so
	// this nesting cannot deadlock.
	subscriber.sendMu.Lock()
	frames := session.replay.replayAfter(watermark, nil)
	frames = append(frames, &execstream.Frame{
		Type:      execstream.FrameTypeNoMoreHistory,
		Watermark: session.replay.nextWatermark,
	})
	session.mu.Unlock()

	// Blocking sends happen after session.mu is released.
	ok := subscriber.sendHistory(frames)
	subscriber.sendMu.Unlock()

	if !ok {
		// The subscriber was closed mid-replay; detach it.
		session.dropSubscriber(subscriber)
	}
}
|
|
|
|
// close tears the session down exactly once: it stops the expiry
// timer, closes all subscribers, cancels the run context, closes the
// runner and transport, and unregisters the session from its registry.
// Safe to call repeatedly and from any goroutine.
func (session *execSession) close() {
	session.mu.Lock()
	if session.closed {
		session.mu.Unlock()

		return
	}

	session.closed = true
	if session.expiryTimer != nil {
		session.expiryTimer.Stop()
		session.expiryTimer = nil
	}

	subscribers := session.takeSubscribersLocked()
	session.mu.Unlock()

	// Close subscribers outside the lock: subscriber.close() takes
	// sendMu and may need to wake blocked senders first.
	closeSubscribers(subscribers)

	session.cancel()
	// Best-effort cleanup — the session is going away regardless.
	_ = session.exec.Close()
	if session.transport != nil {
		_ = session.transport.Close()
	}
	if session.registry != nil {
		session.registry.remove(session.key, session)
	}
}
|
|
|
|
// run executes the command and pumps its output frames into the
// session (replay buffer + subscribers) until the stream ends, records
// any terminal error as an error frame, and marks the session
// finished. Runs on its own goroutine (launched by start()).
func (session *execSession) run() {
	outgoingFrames := make(chan *execstream.Frame)
	runErrCh := make(chan error, 1)

	go func() {
		runErrCh <- session.exec.Run(session.ctx, session.command, outgoingFrames)
		// Run has returned, so no more sends can occur — safe to close.
		close(outgoingFrames)
	}()

	for frame := range outgoingFrames {
		session.recordFrame(frame)
	}

	runErr := <-runErrCh
	// Cancellation is the normal shutdown path, not an error worth
	// surfacing to clients.
	if runErr != nil && !errors.Is(runErr, context.Canceled) {
		session.recordFrame(&execstream.Frame{
			Type:  execstream.FrameTypeError,
			Error: runErr.Error(),
		})
	}

	session.markFinished()
}
|
|
|
|
// recordFrame stores frame in the replay buffer (when enabled) and
// fans it out to every current subscriber. Subscribers that are closed
// or whose queue is full are dropped. No-op once the session is
// closed.
func (session *execSession) recordFrame(frame *execstream.Frame) {
	session.mu.Lock()

	if session.closed {
		session.mu.Unlock()

		return
	}

	if session.policy.replayEnabled {
		// append clones the frame and stamps a watermark on the copy.
		frame = session.replay.append(frame)
	} else {
		frame = cloneExecFrame(frame)
	}

	// Snapshot the subscriber set so enqueue (which takes sendMu and
	// can contend) runs outside session.mu.
	subscribers := make([]*execSessionSubscriber, 0, len(session.subscribers))
	for subscriber := range session.subscribers {
		subscribers = append(subscribers, subscriber)
	}
	session.mu.Unlock()

	for _, subscriber := range subscribers {
		if !subscriber.enqueue(frame) {
			session.dropSubscriber(subscriber)
		}
	}
}
|
|
|
|
// markFinished transitions the session into the finished state exactly
// once. Retained sessions (retainAfterExit) stay alive for exitTTL so
// clients can reconnect and replay the outcome; all other sessions are
// closed immediately. The done channel is closed in both cases.
func (session *execSession) markFinished() {
	session.mu.Lock()
	if session.finished {
		session.mu.Unlock()

		return
	}

	session.finished = true
	shouldClose := !session.policy.retainAfterExit
	if !session.closed && session.policy.retainAfterExit {
		// Give reconnecting clients exitTTL to fetch the exit status
		// before the session expires.
		session.expiryTimer = time.AfterFunc(session.exitTTL, session.expire)
	}

	var subscribers []*execSessionSubscriber
	if shouldClose {
		subscribers = session.takeSubscribersLocked()
	}
	session.mu.Unlock()

	// Subscribers are closed outside the lock (see session.close).
	closeSubscribers(subscribers)

	session.doneOnce.Do(func() {
		close(session.done)
	})

	if shouldClose {
		session.close()
	}
}
|
|
|
|
// expire is the exitTTL timer callback: it closes the retained session.
func (session *execSession) expire() {
	session.close()
}
|
|
|
|
func (session *execSession) takeSubscribersLocked() []*execSessionSubscriber {
|
|
subscribers := make([]*execSessionSubscriber, 0, len(session.subscribers))
|
|
for subscriber := range session.subscribers {
|
|
subscribers = append(subscribers, subscriber)
|
|
}
|
|
session.subscribers = map[*execSessionSubscriber]struct{}{}
|
|
|
|
return subscribers
|
|
}
|
|
|
|
func closeSubscribers(subscribers []*execSessionSubscriber) {
|
|
for _, subscriber := range subscribers {
|
|
subscriber.close()
|
|
}
|
|
}
|
|
|
|
func (session *execSession) dropSubscriber(subscriber *execSessionSubscriber) {
|
|
session.mu.Lock()
|
|
defer session.mu.Unlock()
|
|
|
|
session.detachLocked(subscriber)
|
|
}
|
|
|
|
func cloneExecFrame(frame *execstream.Frame) *execstream.Frame {
|
|
if frame == nil {
|
|
return nil
|
|
}
|
|
|
|
clone := *frame
|
|
if frame.Data != nil {
|
|
clone.Data = append([]byte(nil), frame.Data...)
|
|
}
|
|
if frame.Exit != nil {
|
|
exit := *frame.Exit
|
|
clone.Exit = &exit
|
|
}
|
|
|
|
return &clone
|
|
}
|
|
|
|
func execFrameSize(frame *execstream.Frame) int {
|
|
if frame == nil {
|
|
return 0
|
|
}
|
|
|
|
return len(frame.Data) + len(frame.Error) + 16
|
|
}
|
|
|
|
func isReplayOutputFrame(frame *execstream.Frame) bool {
|
|
if frame == nil {
|
|
return false
|
|
}
|
|
|
|
switch frame.Type {
|
|
case execstream.FrameTypeStdout,
|
|
execstream.FrameTypeStderr,
|
|
execstream.FrameTypeExit,
|
|
execstream.FrameTypeError:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|