285 lines
6.6 KiB
Go
285 lines
6.6 KiB
Go
package base
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/avast/retry-go/v4"
|
|
"github.com/cirruslabs/orchard/internal/dialer"
|
|
"github.com/cirruslabs/orchard/pkg/client"
|
|
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
mapset "github.com/deckarep/golang-set/v2"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// ErrVMFailed is the sentinel error that VM-related failures in this package
// are wrapped with; callers can detect it using errors.Is.
var ErrVMFailed = errors.New("VM failed")
|
|
|
|
type VM struct {
|
|
// Backward compatibility with v1.VM specification's "Status" field
|
|
//
|
|
// "started" is always true after the first "tart run",
|
|
// whereas ConditionReady can be used to tell if a VM
|
|
// is really running or not.
|
|
started atomic.Bool
|
|
|
|
// A more orthogonal alternative to v1.VM specification's "Status" field,
|
|
// which allows a VM to have more than one state.
|
|
//
|
|
// For example, a VM can be both in ConditionReady and ConditionSuspending/
|
|
// ConditionStopping states for a short time. This way in run() we know
|
|
// that we're in a process of rebooting a VM, so we can avoid throwing
|
|
// an error about unexpected VM termination.
|
|
conditions mapset.Set[v1.ConditionType]
|
|
|
|
statusMessage atomic.Pointer[string]
|
|
err atomic.Pointer[error]
|
|
|
|
logger *zap.SugaredLogger
|
|
}
|
|
|
|
func NewVM(logger *zap.SugaredLogger) *VM {
|
|
return &VM{
|
|
conditions: mapset.NewSet(v1.ConditionTypeCloning),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (vm *VM) SetStarted(val bool) {
|
|
vm.started.Store(val)
|
|
}
|
|
|
|
func (vm *VM) Status() v1.VMStatus {
|
|
if vm.Err() != nil {
|
|
return v1.VMStatusFailed
|
|
}
|
|
|
|
if vm.started.Load() {
|
|
return v1.VMStatusRunning
|
|
}
|
|
|
|
return v1.VMStatusPending
|
|
}
|
|
|
|
func (vm *VM) StatusMessage() string {
|
|
status := vm.statusMessage.Load()
|
|
|
|
if status != nil {
|
|
return *status
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (vm *VM) SetStatusMessage(status string) {
|
|
vm.logger.Debugf(status)
|
|
vm.statusMessage.Store(&status)
|
|
}
|
|
|
|
func (vm *VM) Err() error {
|
|
if err := vm.err.Load(); err != nil {
|
|
return *err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (vm *VM) SetErr(err error) {
|
|
if vm.err.CompareAndSwap(nil, &err) {
|
|
vm.logger.Error(err)
|
|
}
|
|
}
|
|
|
|
func (vm *VM) ConditionsSet() mapset.Set[v1.ConditionType] {
|
|
return vm.conditions
|
|
}
|
|
|
|
func (vm *VM) Conditions() []v1.Condition {
|
|
// Only expose a minimum amount of conditions necessary
|
|
// for the Orchard Controller to make decisions
|
|
return []v1.Condition{
|
|
vm.conditionTypeToCondition(v1.ConditionTypeRunning),
|
|
}
|
|
}
|
|
|
|
func (vm *VM) conditionTypeToCondition(conditionType v1.ConditionType) v1.Condition {
|
|
var conditionState v1.ConditionState
|
|
|
|
if vm.ConditionsSet().ContainsOne(conditionType) {
|
|
conditionState = v1.ConditionStateTrue
|
|
} else {
|
|
conditionState = v1.ConditionStateFalse
|
|
}
|
|
|
|
return v1.Condition{
|
|
Type: conditionType,
|
|
State: conditionState,
|
|
}
|
|
}
|
|
|
|
func (vm *VM) Shell(
|
|
ctx context.Context,
|
|
sshUser string,
|
|
sshPassword string,
|
|
script string,
|
|
env map[string]string,
|
|
consumeLine func(line string),
|
|
dialer dialer.Dialer,
|
|
getIP func(ctx context.Context) (string, error),
|
|
) error {
|
|
var sess *ssh.Session
|
|
|
|
// Set default user and password if not provided
|
|
if sshUser == "" && sshPassword == "" {
|
|
sshUser = "admin"
|
|
sshPassword = "admin"
|
|
}
|
|
|
|
// Configure SSH client
|
|
sshConfig := &ssh.ClientConfig{
|
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
return nil
|
|
},
|
|
User: sshUser,
|
|
Auth: []ssh.AuthMethod{
|
|
ssh.Password(sshPassword),
|
|
},
|
|
}
|
|
|
|
if err := retry.Do(func() error {
|
|
ip, err := getIP(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get VM's IP: %w", err)
|
|
}
|
|
|
|
addr := ip + ":22"
|
|
|
|
dialCtx, dialCtxCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
defer dialCtxCancel()
|
|
|
|
var netConn net.Conn
|
|
|
|
if dialer != nil {
|
|
netConn, err = dialer.DialContext(dialCtx, "tcp", addr)
|
|
} else {
|
|
dialer := net.Dialer{}
|
|
|
|
netConn, err = dialer.DialContext(dialCtx, "tcp", addr)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to dial %s: %w", addr, err)
|
|
}
|
|
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(netConn, addr, sshConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("SSH handshake with %s failed: %w", addr, err)
|
|
}
|
|
|
|
sshClient := ssh.NewClient(sshConn, chans, reqs)
|
|
|
|
sess, err = sshClient.NewSession()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open an SSH session on %s: %w", addr, err)
|
|
}
|
|
|
|
return nil
|
|
}, retry.Context(ctx), retry.OnRetry(func(n uint, err error) {
|
|
consumeLine(fmt.Sprintf("attempt %d to establish SSH connection failed: %v", n, err))
|
|
})); err != nil {
|
|
return fmt.Errorf("failed to establish SSH connection: %w", err)
|
|
}
|
|
|
|
// Log output from the virtual machine
|
|
stdout, err := sess.StdoutPipe()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: while opening stdout pipe: %v", ErrVMFailed, err)
|
|
}
|
|
stderr, err := sess.StderrPipe()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: while opening stderr pipe: %v", ErrVMFailed, err)
|
|
}
|
|
var outputReaderWG sync.WaitGroup
|
|
outputReaderWG.Add(1)
|
|
go func() {
|
|
output := io.MultiReader(stdout, stderr)
|
|
|
|
scanner := bufio.NewScanner(output)
|
|
|
|
for scanner.Scan() {
|
|
consumeLine(scanner.Text())
|
|
}
|
|
outputReaderWG.Done()
|
|
}()
|
|
|
|
stdinBuf, err := sess.StdinPipe()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: while opening stdin pipe: %v", ErrVMFailed, err)
|
|
}
|
|
|
|
// start a login shell so all the customization from ~/.zprofile will be picked up
|
|
err = sess.Shell()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: failed to start a shell: %v", ErrVMFailed, err)
|
|
}
|
|
|
|
var scriptBuilder strings.Builder
|
|
|
|
scriptBuilder.WriteString("set -e\n")
|
|
// don't use sess.Setenv since it requires non-default SSH server configuration
|
|
for key, value := range env {
|
|
scriptBuilder.WriteString("export " + key + "=\"" + value + "\"\n")
|
|
}
|
|
scriptBuilder.WriteString(script)
|
|
scriptBuilder.WriteString("\nexit\n")
|
|
|
|
_, err = stdinBuf.Write([]byte(scriptBuilder.String()))
|
|
if err != nil {
|
|
return fmt.Errorf("%w: failed to start script: %v", ErrVMFailed, err)
|
|
}
|
|
outputReaderWG.Wait()
|
|
return sess.Wait()
|
|
}
|
|
|
|
func (vm *VM) RunScript(
|
|
ctx context.Context,
|
|
sshUser string,
|
|
sshPassword string,
|
|
script *v1.VMScript,
|
|
eventStreamer *client.EventStreamer,
|
|
dialer dialer.Dialer,
|
|
getIP func(ctx context.Context) (string, error),
|
|
) {
|
|
if eventStreamer != nil {
|
|
defer func() {
|
|
if err := eventStreamer.Close(); err != nil {
|
|
vm.logger.Errorf("errored during streaming events for startup script: %v", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
consumeLine := func(line string) {
|
|
if eventStreamer == nil {
|
|
return
|
|
}
|
|
|
|
eventStreamer.Stream(v1.Event{
|
|
Kind: v1.EventKindLogLine,
|
|
Timestamp: time.Now().Unix(),
|
|
Payload: line,
|
|
})
|
|
}
|
|
|
|
err := vm.Shell(ctx, sshUser, sshPassword, script.ScriptContent, script.Env, consumeLine, dialer, getIP)
|
|
if err != nil {
|
|
vm.SetErr(fmt.Errorf("%w: failed to run startup script: %v", ErrVMFailed, err))
|
|
}
|
|
}
|