orchard/internal/controller/exec_sessions.go

693 lines
14 KiB
Go

package controller
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
"github.com/cirruslabs/orchard/internal/execstream"
)
// execSessionReplayBufferBytes caps how much output (as accounted by
// execFrameSize) a replay-enabled session keeps buffered for reconnecting
// clients: 4 MiB.
const execSessionReplayBufferBytes = 4 << 20

// execSessionPolicy selects how an exec session behaves when a client detaches
// and when the command finishes:
//   - closeOnDetach: detaching any client tears the whole session down.
//   - retainAfterExit: after the command finishes the session is kept for
//     exitTTL instead of being closed right away.
//   - replayEnabled: output frames are watermarked and buffered so clients can
//     acknowledge them and resume after a reconnect.
type execSessionPolicy struct {
	closeOnDetach   bool
	retainAfterExit bool
	replayEnabled   bool
}

var (
	// legacyExecSessionPolicy ties the session's lifetime to its client.
	legacyExecSessionPolicy = execSessionPolicy{
		closeOnDetach:   true,
		retainAfterExit: false,
		replayEnabled:   false,
	}
	// reconnectableExecSessionPolicy lets clients detach and later resume.
	reconnectableExecSessionPolicy = execSessionPolicy{
		closeOnDetach:   false,
		retainAfterExit: true,
		replayEnabled:   true,
	}
)
// sshExecRunner is the command executor behind an execSession.
type sshExecRunner interface {
	// Stdin returns the writer for the command's standard input. execSession
	// captures it once at construction; a nil result is tolerated (writeStdin
	// then reports that stdin is not available).
	Stdin() io.WriteCloser
	// Run executes command and sends the frames it produces on outgoingFrames.
	// The session treats a context.Canceled result as a normal shutdown, so Run
	// is presumably expected to return once ctx is cancelled — confirm against
	// the implementation. Run must not close outgoingFrames: the session closes
	// it after Run returns.
	Run(ctx context.Context, command string, outgoingFrames chan<- *execstream.Frame) error
	// Close releases the runner; execSession.close calls it during teardown
	// and ignores the returned error.
	Close() error
}
// execSessionKey identifies a session by VM name and session ID.
type execSessionKey struct {
	vmName    string
	sessionID string
}

// execSessionCreation tracks a getOrCreate call that is currently running its
// create function, so concurrent callers for the same key can wait for it.
// done is closed only after session and err have been filled in.
type execSessionCreation struct {
	done    chan struct{}
	session *execSession
	err     error
}

// execSessionRegistry maps keys to registered sessions (sessions) and to
// creations still in progress (creating). mu guards both maps.
type execSessionRegistry struct {
	mu       sync.Mutex
	sessions map[execSessionKey]*execSession
	creating map[execSessionKey]*execSessionCreation
}
// newExecSessionRegistry returns an empty registry whose maps are ready for use.
func newExecSessionRegistry() *execSessionRegistry {
	registry := new(execSessionRegistry)
	registry.sessions = make(map[execSessionKey]*execSession)
	registry.creating = make(map[execSessionKey]*execSessionCreation)
	return registry
}
// get returns the session registered under key, if any. It never creates one
// and does not wait for a creation that is still in progress.
func (registry *execSessionRegistry) get(key execSessionKey) (*execSession, bool) {
	registry.mu.Lock()
	found, exists := registry.sessions[key]
	registry.mu.Unlock()
	return found, exists
}
// getOrCreate returns the session registered under key, creating it with create
// when there is none. Concurrent callers for the same key share a single
// in-flight create instead of each running their own. A caller waiting on
// someone else's create gives up with ctx.Err() if ctx is cancelled first.
//
// The boolean result reports whether this call ran create; it is true even when
// create failed. A session is registered only when create returns a nil error.
func (registry *execSessionRegistry) getOrCreate(
	ctx context.Context,
	key execSessionKey,
	create func() (*execSession, error),
) (*execSession, bool, error) {
	registry.mu.Lock()
	if session, ok := registry.sessions[key]; ok {
		registry.mu.Unlock()
		return session, false, nil
	}
	if creation, ok := registry.creating[key]; ok {
		registry.mu.Unlock()
		select {
		case <-ctx.Done():
			return nil, false, ctx.Err()
		case <-creation.done:
			if creation.err != nil {
				return nil, false, creation.err
			}
			return creation.session, false, nil
		}
	}
	creation := &execSessionCreation{done: make(chan struct{})}
	registry.creating[key] = creation
	registry.mu.Unlock()

	// Publish the outcome from a deferred call so that it also happens if create
	// panics. Without this, waiters would block forever on creation.done and the
	// key would stay in registry.creating, so it could never be created again.
	// err starts non-nil so that a panic is reported to waiters as a failure.
	var session *execSession
	err := errors.New("exec session creation did not complete")
	defer func() {
		registry.mu.Lock()
		defer registry.mu.Unlock()
		delete(registry.creating, key)
		if err == nil {
			registry.sessions[key] = session
		}
		creation.session = session
		creation.err = err
		close(creation.done)
	}()

	session, err = create()
	return session, true, err
}
// remove unregisters key only if it still maps to expected. Because of that
// check, tearing down an old session cannot evict a different session that has
// since been registered under the same key.
func (registry *execSessionRegistry) remove(key execSessionKey, expected *execSession) {
	registry.mu.Lock()
	defer registry.mu.Unlock()
	if current := registry.sessions[key]; current != expected {
		return
	}
	delete(registry.sessions, key)
}
// closeAll closes every registered session. The sessions are snapshotted under
// the lock and closed only after it is released, because execSession.close
// calls back into remove, which takes the same lock.
func (registry *execSessionRegistry) closeAll() {
	registry.mu.Lock()
	snapshot := make([]*execSession, 0, len(registry.sessions))
	for _, registered := range registry.sessions {
		snapshot = append(snapshot, registered)
	}
	registry.mu.Unlock()

	for i := range snapshot {
		snapshot[i].close()
	}
}
// execReplayFrame is one buffered frame together with its accounted size (as
// computed by execFrameSize when it was appended).
type execReplayFrame struct {
	frame *execstream.Frame
	size  int
}

// execReplayBuffer holds watermarked frames, oldest first, so that a client can
// be re-sent everything after the last watermark it saw.
//
// Despite its name, nextWatermark is the watermark given to the most recently
// appended frame: append increments it before stamping, so the first frame
// gets 1. Frames at or below ackedWatermark are discarded. bufferBytes is the
// summed size of the retained frames and is kept within
// execSessionReplayBufferBytes by trimToLimit.
type execReplayBuffer struct {
	frames         []execReplayFrame
	bufferBytes    int
	nextWatermark  uint64
	ackedWatermark uint64
}
// append stores a private copy of frame stamped with the next watermark. It
// then trims acknowledged frames and enforces the size limit, and returns the
// stamped copy. The copy is what subscribers should receive.
func (buffer *execReplayBuffer) append(frame *execstream.Frame) *execstream.Frame {
	stamped := cloneExecFrame(frame)
	buffer.nextWatermark++
	stamped.Watermark = buffer.nextWatermark

	record := execReplayFrame{frame: stamped, size: execFrameSize(stamped)}
	buffer.frames = append(buffer.frames, record)
	buffer.bufferBytes += record.size

	buffer.trimAcknowledged()
	buffer.trimToLimit()
	return stamped
}
// ack records that the client has received every frame up to and including
// watermark, and discards those frames. Acknowledgements that do not advance
// ackedWatermark are ignored.
//
// watermark is clamped to the last issued watermark. If a value from the
// future were accepted, append's trimAcknowledged call would discard each new
// frame as soon as it was stored, because its watermark would already count as
// acknowledged. That would silently disable replay for the rest of the session.
func (buffer *execReplayBuffer) ack(watermark uint64) {
	if watermark > buffer.nextWatermark {
		watermark = buffer.nextWatermark
	}
	if watermark <= buffer.ackedWatermark {
		return
	}
	buffer.ackedWatermark = watermark
	buffer.trimAcknowledged()
}
// replayAfter appends to frames every buffered frame whose watermark is strictly
// greater than watermark, in buffer order, and returns the extended slice. The
// frames are shared with the buffer, not copied.
func (buffer *execReplayBuffer) replayAfter(watermark uint64, frames []*execstream.Frame) []*execstream.Frame {
	for i := range buffer.frames {
		if candidate := buffer.frames[i].frame; candidate.Watermark > watermark {
			frames = append(frames, candidate)
		}
	}
	return frames
}
// trimAcknowledged drops frames from the front of the buffer whose watermark is
// at or below ackedWatermark.
func (buffer *execReplayBuffer) trimAcknowledged() {
	for len(buffer.frames) > 0 && buffer.frames[0].frame.Watermark <= buffer.ackedWatermark {
		buffer.dropOldest()
	}
}

// trimToLimit drops the oldest frames, acknowledged or not, until bufferBytes
// fits within execSessionReplayBufferBytes. Frames dropped here can no longer
// be replayed.
func (buffer *execReplayBuffer) trimToLimit() {
	for buffer.bufferBytes > execSessionReplayBufferBytes && len(buffer.frames) > 0 {
		buffer.dropOldest()
	}
}

// dropOldest removes the first retained frame and subtracts its size.
// Re-slicing alone would leave the frame, and its Data, reachable through the
// slice's backing array until the next reallocation. Real memory use could then
// exceed the replay budget, so the vacated slot is zeroed first.
func (buffer *execReplayBuffer) dropOldest() {
	buffer.bufferBytes -= buffer.frames[0].size
	buffer.frames[0] = execReplayFrame{}
	buffer.frames = buffer.frames[1:]
}
// execSessionSubscriber is one attached client's feed of session frames.
//
// Frames are delivered through the buffered frames channel. closed is closed
// when the subscriber is shut down, and frames is closed right after. sendMu
// serializes sends with each other and with the close of frames, and guards
// sentWatermark. sentWatermark is the highest output-frame watermark already
// delivered; it is used to skip output frames that have been sent before, e.g.
// by sendHistory.
type execSessionSubscriber struct {
	frames        chan *execstream.Frame
	closed        chan struct{}
	closeOnce     sync.Once
	sendMu        sync.Mutex
	sentWatermark uint64
}
// newExecSessionSubscriber returns an open subscriber that can queue up to 128
// frames before enqueue starts reporting it as too slow.
func newExecSessionSubscriber() *execSessionSubscriber {
	const queueDepth = 128
	subscriber := &execSessionSubscriber{}
	subscriber.frames = make(chan *execstream.Frame, queueDepth)
	subscriber.closed = make(chan struct{})
	return subscriber
}
// enqueue queues a copy of frame for the subscriber without blocking. An output
// frame that was already delivered (by watermark) is skipped and counts as
// success. It returns false when the subscriber is closed or its queue is
// full; recordFrame drops the subscriber in that case.
func (subscriber *execSessionSubscriber) enqueue(frame *execstream.Frame) bool {
	subscriber.sendMu.Lock()
	defer subscriber.sendMu.Unlock()

	if subscriber.alreadySentLocked(frame) {
		return true
	}
	// Check for closure on its own first. If the send below and the closed case
	// were both ready, select would pick one at random.
	select {
	case <-subscriber.closed:
		return false
	default:
	}

	outgoing := subscriber.markSentLocked(frame)
	select {
	case subscriber.frames <- outgoing:
		return true
	case <-subscriber.closed:
		return false
	default:
		return false
	}
}
// sendHistory delivers frames in order using blocking sends. It stops and
// returns false at the first frame that cannot be delivered because the
// subscriber closed. The caller must hold sendMu.
func (subscriber *execSessionSubscriber) sendHistory(frames []*execstream.Frame) bool {
	ok := true
	for i := 0; ok && i < len(frames); i++ {
		ok = subscriber.sendLocked(frames[i])
	}
	return ok
}
// sendLocked queues a copy of frame, blocking until there is room in the queue
// or the subscriber is closed. An output frame that was already delivered is
// skipped and counts as success. The caller must hold sendMu.
func (subscriber *execSessionSubscriber) sendLocked(frame *execstream.Frame) bool {
	if subscriber.alreadySentLocked(frame) {
		return true
	}
	// Bail out on a closed subscriber up front. Otherwise a send that happens
	// to find buffer space could win the random choice in the select below.
	select {
	case <-subscriber.closed:
		return false
	default:
	}

	outgoing := subscriber.markSentLocked(frame)
	select {
	case <-subscriber.closed:
		return false
	case subscriber.frames <- outgoing:
		return true
	}
}
// alreadySentLocked reports whether frame is a watermarked output frame at or
// below the highest watermark already delivered to this subscriber. Frames
// without a watermark (0) are never treated as duplicates. The caller must
// hold sendMu.
func (subscriber *execSessionSubscriber) alreadySentLocked(frame *execstream.Frame) bool {
	if !isReplayOutputFrame(frame) || frame.Watermark == 0 {
		return false
	}
	return frame.Watermark <= subscriber.sentWatermark
}
// markSentLocked returns a private copy of frame to hand to the subscriber. For
// output frames it also advances sentWatermark to the frame's watermark if that
// is higher. The caller must hold sendMu.
func (subscriber *execSessionSubscriber) markSentLocked(frame *execstream.Frame) *execstream.Frame {
	copied := cloneExecFrame(frame)
	if !isReplayOutputFrame(copied) {
		return copied
	}
	if copied.Watermark > subscriber.sentWatermark {
		subscriber.sentWatermark = copied.Watermark
	}
	return copied
}
// close shuts the subscriber down; only the first call has any effect.
func (subscriber *execSessionSubscriber) close() {
	subscriber.closeOnce.Do(func() {
		// Close the signal channel first and without sendMu. This way a sender
		// blocked in sendLocked (which holds sendMu) wakes up and releases the
		// mutex.
		close(subscriber.closed)
		// Close frames under sendMu so that no send can race with the close.
		subscriber.sendMu.Lock()
		close(subscriber.frames)
		subscriber.sendMu.Unlock()
	})
}
// execSession is one remote command run through an sshExecRunner and shared by
// every subscriber attached to it. Its policy decides three things: whether
// the session is torn down when a client detaches, whether it outlives the
// command (for exitTTL), and whether output is buffered for replay.
type execSession struct {
	key       execSessionKey
	command   string
	exec      sshExecRunner
	transport net.Conn
	registry  *execSessionRegistry
	exitTTL   time.Duration
	policy    execSessionPolicy
	ctx       context.Context
	cancel    context.CancelFunc

	// stdinMu serializes writes to and the close of stdin, and guards
	// stdinClosed. It is deliberately separate from mu: a stdin Write may block
	// until the remote command reads its input, and mu must stay available
	// meanwhile for recordFrame and close.
	stdinMu     sync.Mutex
	stdin       io.WriteCloser
	stdinClosed bool

	// mu guards the fields in this group.
	mu          sync.Mutex
	subscribers map[*execSessionSubscriber]struct{}
	replay      execReplayBuffer
	started     bool
	finished    bool
	closed      bool
	expiryTimer *time.Timer

	startOnce sync.Once
	done      chan struct{}
	doneOnce  sync.Once
}

// newExecSession builds a session with its own cancellable context derived from
// context.Background(). The command is not launched until start is called.
func newExecSession(
	key execSessionKey,
	command string,
	exec sshExecRunner,
	transport net.Conn,
	registry *execSessionRegistry,
	exitTTL time.Duration,
	policy execSessionPolicy,
) *execSession {
	ctx, cancel := context.WithCancel(context.Background())
	return newExecSessionWithContext(
		ctx,
		cancel,
		key,
		command,
		exec,
		transport,
		registry,
		exitTTL,
		policy,
	)
}

// newExecSessionWithContext is newExecSession with a caller-supplied context.
// run passes ctx to the runner, and close invokes cancel. If either ctx or
// cancel is nil, a fresh cancellable background context replaces both. The
// runner's stdin writer is fetched once, here.
func newExecSessionWithContext(
	ctx context.Context,
	cancel context.CancelFunc,
	key execSessionKey,
	command string,
	exec sshExecRunner,
	transport net.Conn,
	registry *execSessionRegistry,
	exitTTL time.Duration,
	policy execSessionPolicy,
) *execSession {
	if ctx == nil || cancel == nil {
		ctx, cancel = context.WithCancel(context.Background())
	}
	session := &execSession{
		key:         key,
		command:     command,
		exec:        exec,
		transport:   transport,
		registry:    registry,
		exitTTL:     exitTTL,
		policy:      policy,
		ctx:         ctx,
		cancel:      cancel,
		stdin:       exec.Stdin(),
		subscribers: map[*execSessionSubscriber]struct{}{},
		done:        make(chan struct{}),
	}
	return session
}

// commandMatches reports whether command is compatible with this session. An
// empty command matches any session; anything else must equal the session's
// command exactly.
func (session *execSession) commandMatches(command string) bool {
	return command == "" || session.command == command
}

// start launches the command in a background goroutine. Only the first call
// has any effect, and nothing is launched if the session is already closed.
func (session *execSession) start() {
	session.startOnce.Do(func() {
		session.mu.Lock()
		if session.closed {
			session.mu.Unlock()
			return
		}
		session.started = true
		session.mu.Unlock()
		go session.run()
	})
}

// closeIfUnused closes the session if it was never started and has no
// subscribers.
// NOTE(review): the check and close() are not atomic, so a subscriber that
// attaches in between is closed together with the session.
func (session *execSession) closeIfUnused() {
	session.mu.Lock()
	unused := !session.started && len(session.subscribers) == 0
	session.mu.Unlock()
	if unused {
		session.close()
	}
}

// attach registers a new subscriber that receives every frame recorded from
// now on. It fails once the session has been closed.
func (session *execSession) attach() (*execSessionSubscriber, error) {
	session.mu.Lock()
	defer session.mu.Unlock()
	if session.closed {
		return nil, errors.New("exec session is closed")
	}
	subscriber := newExecSessionSubscriber()
	session.subscribers[subscriber] = struct{}{}
	return subscriber, nil
}

// detach removes and closes subscriber. Under a closeOnDetach policy it instead
// closes the whole session, which also closes every subscriber.
func (session *execSession) detach(subscriber *execSessionSubscriber) {
	if session.policy.closeOnDetach {
		session.close()
		return
	}
	session.mu.Lock()
	defer session.mu.Unlock()
	session.detachLocked(subscriber)
}

// detachLocked removes and closes subscriber if it is still attached; otherwise
// it does nothing. The caller must hold session.mu.
func (session *execSession) detachLocked(subscriber *execSessionSubscriber) {
	if _, ok := session.subscribers[subscriber]; !ok {
		return
	}
	delete(session.subscribers, subscriber)
	subscriber.close()
}

// writeStdin forwards data to the command's stdin. An empty data slice closes
// stdin instead. It returns an error when the runner exposed no stdin or stdin
// has already been closed, and otherwise returns the writer's error.
func (session *execSession) writeStdin(data []byte) error {
	// Only stdinMu is held here, never session.mu. Write may block until the
	// remote side reads. Under session.mu that would stall recordFrame, and
	// with it the run loop and the runner's unbuffered frame channel. That can
	// deadlock a command that echoes stdin to stdout, and it would also keep
	// close() from tearing down the runner to unblock the write.
	session.stdinMu.Lock()
	defer session.stdinMu.Unlock()
	if session.stdin == nil || session.stdinClosed {
		return errors.New("this exec session has no stdin enabled or it is already closed")
	}
	if len(data) == 0 {
		if err := session.stdin.Close(); err != nil {
			return err
		}
		session.stdinClosed = true
		return nil
	}
	_, err := session.stdin.Write(data)
	return err
}
// ack forwards a client acknowledgement to the replay buffer, letting it
// discard frames up to watermark. It does nothing when replay is disabled.
func (session *execSession) ack(watermark uint64) {
	if session.policy.replayEnabled {
		session.mu.Lock()
		session.replay.ack(watermark)
		session.mu.Unlock()
	}
}
// sendHistory replays to subscriber every buffered frame whose watermark is
// above watermark, followed by a NoMoreHistory frame carrying the latest issued
// watermark. It does nothing when replay is disabled or the subscriber is no
// longer attached. If delivery fails, the subscriber is dropped.
func (session *execSession) sendHistory(
	subscriber *execSessionSubscriber,
	watermark uint64,
) {
	if !session.policy.replayEnabled {
		return
	}
	session.mu.Lock()
	if _, ok := session.subscribers[subscriber]; !ok {
		session.mu.Unlock()
		return
	}
	// Take sendMu while session.mu is still held, and keep it until the history
	// is sent. recordFrame's enqueue also needs sendMu, so a frame recorded after
	// this snapshot is queued only once the history has been delivered. A frame
	// that is in the snapshot and also enqueued live is skipped by the
	// subscriber's watermark check.
	subscriber.sendMu.Lock()
	frames := session.replay.replayAfter(watermark, nil)
	frames = append(frames, &execstream.Frame{
		Type:      execstream.FrameTypeNoMoreHistory,
		Watermark: session.replay.nextWatermark,
	})
	session.mu.Unlock()
	ok := subscriber.sendHistory(frames)
	subscriber.sendMu.Unlock()
	// dropSubscriber takes session.mu, so it must run after sendMu is released.
	if !ok {
		session.dropSubscriber(subscriber)
	}
}
// close tears the session down. It marks the session closed, stops a pending
// expiry timer and closes every subscriber. It then cancels the run context,
// closes the runner and the transport (ignoring their errors), and removes the
// session from its registry. Only the first call has any effect.
//
// If the session was never started, done is closed here as well. start() will
// no longer launch run once closed is set, so markFinished would never close it.
func (session *execSession) close() {
	session.mu.Lock()
	if session.closed {
		session.mu.Unlock()
		return
	}
	session.closed = true
	neverStarted := !session.started
	if session.expiryTimer != nil {
		session.expiryTimer.Stop()
		session.expiryTimer = nil
	}
	subscribers := session.takeSubscribersLocked()
	session.mu.Unlock()

	closeSubscribers(subscribers)
	session.cancel()
	// Best effort: the session is going away regardless of these errors.
	_ = session.exec.Close()
	if session.transport != nil {
		_ = session.transport.Close()
	}
	if session.registry != nil {
		session.registry.remove(session.key, session)
	}
	if neverStarted {
		session.doneOnce.Do(func() {
			close(session.done)
		})
	}
}
// run executes the command and records every frame it produces. A run error
// other than context.Canceled is recorded as an error frame. Finally the
// session is marked finished.
func (session *execSession) run() {
	frames := make(chan *execstream.Frame)
	result := make(chan error, 1)

	go func() {
		// frames is closed only after the result is stored, so the drain loop
		// below ends after Run has returned.
		defer close(frames)
		result <- session.exec.Run(session.ctx, session.command, frames)
	}()

	for frame := range frames {
		session.recordFrame(frame)
	}

	if err := <-result; err != nil && !errors.Is(err, context.Canceled) {
		session.recordFrame(&execstream.Frame{
			Type:  execstream.FrameTypeError,
			Error: err.Error(),
		})
	}
	session.markFinished()
}
// recordFrame stores frame (watermarked in the replay buffer when replay is
// enabled, otherwise just copied) and fans it out to the current subscribers.
// A subscriber that cannot take the frame without blocking is dropped. Frames
// recorded after the session is closed are ignored.
func (session *execSession) recordFrame(frame *execstream.Frame) {
	session.mu.Lock()
	if session.closed {
		session.mu.Unlock()
		return
	}

	var stored *execstream.Frame
	if session.policy.replayEnabled {
		stored = session.replay.append(frame)
	} else {
		stored = cloneExecFrame(frame)
	}

	targets := make([]*execSessionSubscriber, 0, len(session.subscribers))
	for target := range session.subscribers {
		targets = append(targets, target)
	}
	session.mu.Unlock()

	// Deliver outside session.mu: enqueue takes each subscriber's sendMu.
	for _, target := range targets {
		if delivered := target.enqueue(stored); !delivered {
			session.dropSubscriber(target)
		}
	}
}
// markFinished is called by run once the command has completed. Only the first
// call has any effect: it marks the session finished and closes done exactly
// once. What happens next depends on the policy:
//   - with retainAfterExit, an expiry timer is armed (unless the session is
//     already closed) that closes the session after exitTTL;
//   - otherwise, all subscribers are closed and then the session itself.
func (session *execSession) markFinished() {
	session.mu.Lock()
	if session.finished {
		session.mu.Unlock()
		return
	}
	session.finished = true
	shouldClose := !session.policy.retainAfterExit
	if !session.closed && session.policy.retainAfterExit {
		session.expiryTimer = time.AfterFunc(session.exitTTL, session.expire)
	}
	var subscribers []*execSessionSubscriber
	if shouldClose {
		// Take the subscribers under the lock; they are closed below after it is
		// released.
		subscribers = session.takeSubscribersLocked()
	}
	session.mu.Unlock()
	closeSubscribers(subscribers)
	session.doneOnce.Do(func() {
		close(session.done)
	})
	if shouldClose {
		session.close()
	}
}
// expire is the expiry-timer callback armed by markFinished. It closes the
// session.
func (session *execSession) expire() {
	session.close()
}
// takeSubscribersLocked detaches every subscriber from the session, leaving an
// empty set, and returns them without closing them. The caller must hold
// session.mu.
func (session *execSession) takeSubscribersLocked() []*execSessionSubscriber {
	taken := session.subscribers
	session.subscribers = map[*execSessionSubscriber]struct{}{}

	list := make([]*execSessionSubscriber, 0, len(taken))
	for subscriber := range taken {
		list = append(list, subscriber)
	}
	return list
}
// closeSubscribers closes each subscriber in the slice.
func closeSubscribers(subscribers []*execSessionSubscriber) {
	for i := range subscribers {
		subscribers[i].close()
	}
}
// dropSubscriber detaches and closes a subscriber that failed to accept frames.
// It takes session.mu itself.
func (session *execSession) dropSubscriber(subscriber *execSessionSubscriber) {
	session.mu.Lock()
	session.detachLocked(subscriber)
	session.mu.Unlock()
}
// cloneExecFrame returns a copy of frame whose Data and Exit no longer alias the
// original's. A nil frame yields nil.
// NOTE(review): only Data and Exit are deep-copied; any other reference-typed
// fields execstream.Frame may have are still shared — confirm.
func cloneExecFrame(frame *execstream.Frame) *execstream.Frame {
	if frame == nil {
		return nil
	}
	clone := *frame
	if frame.Data != nil {
		// Start from a non-nil, exactly sized slice. append([]byte(nil), empty...)
		// returns nil, which would turn an empty payload into a missing one in
		// the copy.
		clone.Data = append(make([]byte, 0, len(frame.Data)), frame.Data...)
	}
	if frame.Exit != nil {
		exit := *frame.Exit
		clone.Exit = &exit
	}
	return &clone
}
// execFrameSize is the size charged against the replay budget for frame: the
// length of its Data and Error plus a flat 16 per frame (presumably a rough
// allowance for fixed overhead). A nil frame costs nothing.
func execFrameSize(frame *execstream.Frame) int {
	const perFrame = 16
	size := 0
	if frame != nil {
		size = perFrame + len(frame.Data) + len(frame.Error)
	}
	return size
}
// isReplayOutputFrame reports whether frame is a stdout, stderr, exit or error
// frame. These are the frames that subscribers deduplicate by watermark. A nil
// frame is not an output frame.
func isReplayOutputFrame(frame *execstream.Frame) bool {
	if frame == nil {
		return false
	}
	kind := frame.Type
	return kind == execstream.FrameTypeStdout ||
		kind == execstream.FrameTypeStderr ||
		kind == execstream.FrameTypeExit ||
		kind == execstream.FrameTypeError
}