orchard/internal/worker/vmmanager/base/base.go

285 lines
6.6 KiB
Go

package base
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/avast/retry-go/v4"
"github.com/cirruslabs/orchard/internal/dialer"
"github.com/cirruslabs/orchard/pkg/client"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
mapset "github.com/deckarep/golang-set/v2"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
)
// ErrVMFailed is the sentinel error that Shell and RunScript wrap (via %w)
// into the errors they produce, so it can be matched with errors.Is().
var ErrVMFailed = errors.New("VM failed")
// VM bundles the state tracked for a virtual machine: the backward-compatible
// "started" flag, the set of conditions, the latest status message and the
// recorded error.
type VM struct {
// Backward compatibility with v1.VM specification's "Status" field
//
// "started" is always true after the first "tart run",
// whereas ConditionReady can be used to tell if a VM
// is really running or not.
started atomic.Bool
// A more orthogonal alternative to v1.VM specification's "Status" field,
// which allows a VM to have more than one state.
//
// For example, a VM can be both in ConditionReady and ConditionSuspending/
// ConditionStopping states for a short time. This way in run() we know
// that we're in a process of rebooting a VM, so we can avoid throwing
// an error about unexpected VM termination.
conditions mapset.Set[v1.ConditionType]
// Most recent message stored by SetStatusMessage; nil until the first call,
// in which case StatusMessage reports an empty string.
statusMessage atomic.Pointer[string]
// Error stored by SetErr; nil while nothing was recorded. SetErr only
// stores into an empty slot, so later errors do not replace the first one.
err atomic.Pointer[error]
// Logger used by SetStatusMessage, SetErr and RunScript.
logger *zap.SugaredLogger
}
// NewVM returns a VM that logs through the given logger and whose
// condition set initially contains only v1.ConditionTypeCloning.
func NewVM(logger *zap.SugaredLogger) *VM {
	vm := &VM{logger: logger}
	vm.conditions = mapset.NewSet(v1.ConditionTypeCloning)

	return vm
}
// SetStarted atomically stores val into the "started" flag, which Status
// consults to tell a running VM from a pending one.
func (vm *VM) SetStarted(val bool) {
vm.started.Store(val)
}
// Status derives the backward-compatible v1.VMStatus of the VM:
// failed once an error was recorded, running once the "started"
// flag is set, and pending otherwise.
func (vm *VM) Status() v1.VMStatus {
	switch {
	case vm.Err() != nil:
		// A recorded error takes precedence over the "started" flag
		return v1.VMStatusFailed
	case vm.started.Load():
		return v1.VMStatusRunning
	default:
		return v1.VMStatusPending
	}
}
// StatusMessage returns the message most recently stored by
// SetStatusMessage, or an empty string if none was stored yet.
func (vm *VM) StatusMessage() string {
	if msg := vm.statusMessage.Load(); msg != nil {
		return *msg
	}

	return ""
}
// SetStatusMessage logs the given status at the debug level and stores it
// so that it can be later retrieved with StatusMessage.
func (vm *VM) SetStatusMessage(status string) {
	// Use Debug() instead of Debugf(): the status is data, not a format
	// string, so a literal "%" in it must not be interpreted as a verb
	vm.logger.Debug(status)

	vm.statusMessage.Store(&status)
}
// Err returns the error stored by SetErr, or nil if no error was stored.
func (vm *VM) Err() error {
	stored := vm.err.Load()
	if stored == nil {
		return nil
	}

	return *stored
}
// SetErr records err for the VM and logs it, but only when no error was
// recorded before: the compare-and-swap leaves an already stored error
// in place, in which case nothing is logged either.
func (vm *VM) SetErr(err error) {
	if !vm.err.CompareAndSwap(nil, &err) {
		return
	}

	vm.logger.Error(err)
}
// ConditionsSet returns the VM's own condition set (not a copy), so changes
// made through the returned set are reflected in Conditions.
func (vm *VM) ConditionsSet() mapset.Set[v1.ConditionType] {
return vm.conditions
}
// Conditions reports the VM's conditions in the v1.Condition form.
//
// Only the minimum set of conditions that the Orchard Controller needs
// for making decisions is exposed, which is currently just the "running"
// condition.
func (vm *VM) Conditions() []v1.Condition {
	running := vm.conditionTypeToCondition(v1.ConditionTypeRunning)

	return []v1.Condition{running}
}
// conditionTypeToCondition builds a v1.Condition for the given condition type,
// with the state set to true if the type is currently present in the VM's
// condition set and to false otherwise.
func (vm *VM) conditionTypeToCondition(conditionType v1.ConditionType) v1.Condition {
	condition := v1.Condition{
		Type:  conditionType,
		State: v1.ConditionStateFalse,
	}

	if vm.ConditionsSet().ContainsOne(conditionType) {
		condition.State = v1.ConditionStateTrue
	}

	return condition
}
// Shell connects to the VM over SSH, starts a login shell, feeds it the
// script (preceded by "set -e" and the exports for env, followed by "exit")
// and passes every line of the shell's output to consumeLine.
//
// Parameters:
//   - ctx bounds the connection establishment (IP retrieval, dialing and
//     the retries); it is not consulted once the SSH session is open.
//   - sshUser and sshPassword are the SSH credentials; when both are empty,
//     "admin"/"admin" is used.
//   - script is the shell script to run and env is the set of variables
//     exported before it.
//   - consumeLine receives the output lines, as well as a diagnostic line
//     for each failed connection attempt. Calls to it are serialized, so it
//     is never invoked concurrently.
//   - dialer is used to reach the VM's port 22; when nil, a default
//     net.Dialer is used.
//   - getIP resolves the VM's IP address and is invoked on each attempt.
//
// The returned error is nil when the shell exits successfully. Failures that
// occur while setting up the session's pipes, starting the shell or sending
// the script wrap ErrVMFailed; connection failures and the shell's own exit
// error are returned without it.
func (vm *VM) Shell(
	ctx context.Context,
	sshUser string,
	sshPassword string,
	script string,
	env map[string]string,
	consumeLine func(line string),
	dialer dialer.Dialer,
	getIP func(ctx context.Context) (string, error),
) error {
	// Set default user and password if not provided
	if sshUser == "" && sshPassword == "" {
		sshUser = "admin"
		sshPassword = "admin"
	}

	// Configure SSH client
	sshConfig := &ssh.ClientConfig{
		// Accept any host key
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			return nil
		},
		User: sshUser,
		Auth: []ssh.AuthMethod{
			ssh.Password(sshPassword),
		},
	}

	var (
		sshClient *ssh.Client
		sess      *ssh.Session
	)

	if err := retry.Do(func() error {
		ip, err := getIP(ctx)
		if err != nil {
			return fmt.Errorf("failed to get VM's IP: %w", err)
		}

		// JoinHostPort() also yields a valid address for IPv6 literals
		addr := net.JoinHostPort(ip, "22")

		dialCtx, dialCtxCancel := context.WithTimeout(ctx, 5*time.Second)
		defer dialCtxCancel()

		var netConn net.Conn

		if dialer != nil {
			netConn, err = dialer.DialContext(dialCtx, "tcp", addr)
		} else {
			var defaultDialer net.Dialer

			netConn, err = defaultDialer.DialContext(dialCtx, "tcp", addr)
		}
		if err != nil {
			return fmt.Errorf("failed to dial %s: %w", addr, err)
		}

		sshConn, chans, reqs, err := ssh.NewClientConn(netConn, addr, sshConfig)
		if err != nil {
			// Make sure that a failed attempt does not leave the connection
			// open; the close error is irrelevant at this point
			_ = netConn.Close()

			return fmt.Errorf("SSH handshake with %s failed: %w", addr, err)
		}

		client := ssh.NewClient(sshConn, chans, reqs)

		session, err := client.NewSession()
		if err != nil {
			// Do not leak the client (and its connection) across retries
			_ = client.Close()

			return fmt.Errorf("failed to open an SSH session on %s: %w", addr, err)
		}

		sshClient, sess = client, session

		return nil
	}, retry.Context(ctx), retry.OnRetry(func(n uint, err error) {
		consumeLine(fmt.Sprintf("attempt %d to establish SSH connection failed: %v", n, err))
	})); err != nil {
		return fmt.Errorf("failed to establish SSH connection: %w", err)
	}

	// Release the session and the underlying connection on every return path;
	// close errors are ignored since there is nothing left to do about them
	defer func() {
		_ = sess.Close()
		_ = sshClient.Close()
	}()

	// Obtain all the pipes before starting any goroutines,
	// so that an early return does not leave readers behind
	stdout, err := sess.StdoutPipe()
	if err != nil {
		return fmt.Errorf("%w: while opening stdout pipe: %v", ErrVMFailed, err)
	}

	stderr, err := sess.StderrPipe()
	if err != nil {
		return fmt.Errorf("%w: while opening stderr pipe: %v", ErrVMFailed, err)
	}

	stdinBuf, err := sess.StdinPipe()
	if err != nil {
		return fmt.Errorf("%w: while opening stdin pipe: %v", ErrVMFailed, err)
	}

	// Log output from the virtual machine
	//
	// Both streams are read concurrently: reading stderr only after stdout
	// has ended would reorder the output and could stall a script that
	// writes a lot to stderr. The mutex keeps consumeLine calls serialized.
	var (
		consumeLineMu  sync.Mutex
		outputReaderWG sync.WaitGroup
	)

	emitLine := func(line string) {
		consumeLineMu.Lock()
		defer consumeLineMu.Unlock()

		consumeLine(line)
	}

	for _, stream := range []io.Reader{stdout, stderr} {
		outputReaderWG.Add(1)

		go func(stream io.Reader) {
			defer outputReaderWG.Done()

			scanner := bufio.NewScanner(stream)

			for scanner.Scan() {
				emitLine(scanner.Text())
			}

			if err := scanner.Err(); err != nil {
				emitLine(fmt.Sprintf("failed to read script output: %v", err))

				// Keep draining the stream, so that the remote side
				// is not blocked on output that nobody reads anymore
				_, _ = io.Copy(io.Discard, stream)
			}
		}(stream)
	}

	// start a login shell so all the customization from ~/.zprofile will be picked up
	err = sess.Shell()
	if err != nil {
		return fmt.Errorf("%w: failed to start a shell: %v", ErrVMFailed, err)
	}

	var scriptBuilder strings.Builder

	scriptBuilder.WriteString("set -e\n")

	// don't use sess.Setenv since it requires non-default SSH server configuration
	//
	// NOTE(review): values are placed between double quotes as-is and the
	// exports are emitted in map iteration order, so a value containing
	// a double quote breaks the export and values are subject to shell
	// expansion — confirm whether callers rely on that before escaping.
	for key, value := range env {
		scriptBuilder.WriteString("export " + key + "=\"" + value + "\"\n")
	}

	scriptBuilder.WriteString(script)
	scriptBuilder.WriteString("\nexit\n")

	_, err = io.WriteString(stdinBuf, scriptBuilder.String())
	if err != nil {
		return fmt.Errorf("%w: failed to start script: %v", ErrVMFailed, err)
	}

	// Wait until all the output is consumed before collecting the exit status
	outputReaderWG.Wait()

	return sess.Wait()
}
func (vm *VM) RunScript(
ctx context.Context,
sshUser string,
sshPassword string,
script *v1.VMScript,
eventStreamer *client.EventStreamer,
dialer dialer.Dialer,
getIP func(ctx context.Context) (string, error),
) {
if eventStreamer != nil {
defer func() {
if err := eventStreamer.Close(); err != nil {
vm.logger.Errorf("errored during streaming events for startup script: %v", err)
}
}()
}
consumeLine := func(line string) {
if eventStreamer == nil {
return
}
eventStreamer.Stream(v1.Event{
Kind: v1.EventKindLogLine,
Timestamp: time.Now().Unix(),
Payload: line,
})
}
err := vm.Shell(ctx, sshUser, sshPassword, script.ScriptContent, script.Env, consumeLine, dialer, getIP)
if err != nil {
vm.SetErr(fmt.Errorf("%w: failed to run startup script: %v", ErrVMFailed, err))
}
}