// orchard/internal/worker/worker.go
// (source listing metadata: 338 lines, 8.6 KiB, Go)

package worker
import (
"context"
"errors"
"fmt"
"github.com/avast/retry-go/v4"
"github.com/cirruslabs/orchard/internal/worker/iokitregistry"
"github.com/cirruslabs/orchard/internal/worker/vmmanager"
"github.com/cirruslabs/orchard/pkg/client"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/hashicorp/go-multierror"
"go.uber.org/zap"
"google.golang.org/grpc/metadata"
"os"
"time"
)
// pollInterval is the period between successive worker-update / VM-sync
// iterations within a session.
const pollInterval = 5 * time.Second

var (
	// ErrPollFailed wraps failures to read or update this worker's resource
	// on the controller.
	ErrPollFailed = errors.New("failed to poll controller")

	// ErrRegistrationFailed is returned when a session cannot register the
	// worker with the controller.
	ErrRegistrationFailed = errors.New("failed to register worker on the controller")
)
// Worker runs VMs assigned to it by the orchard controller and keeps its own
// worker resource on the controller up to date.
type Worker struct {
	// name identifies this worker on the controller; New defaults it to the
	// host name when no option sets it.
	name string

	// vmm manages the VMs running locally on this worker.
	vmm *vmmanager.VMManager

	// client talks to the controller API.
	client *client.Client

	// logger receives diagnostics; New defaults it to a no-op logger.
	logger *zap.SugaredLogger
}
// New creates a Worker that talks to the controller through client.
//
// The options in opts are applied in order. Afterwards, an unset name falls
// back to the machine's host name and an unset logger falls back to a no-op
// logger. An error is returned only if the host name cannot be determined.
func New(client *client.Client, opts ...Option) (*Worker, error) {
	w := &Worker{
		client: client,
		vmm:    vmmanager.New(),
	}

	for _, apply := range opts {
		apply(w)
	}

	if w.name == "" {
		host, err := os.Hostname()
		if err != nil {
			return nil, err
		}
		w.name = host
	}

	if w.logger == nil {
		w.logger = zap.NewNop().Sugar()
	}

	return w, nil
}
// Run drives the worker by running sessions back to back until one of them
// returns an error, such as a registration failure or the cancellation of
// ctx. Run returns that error.
//
// A session that ends with a nil error has hit a failed update or VM sync.
// Run then waits pollInterval (or until ctx is done) before starting the next
// session, so a failing controller is not re-registered against in a tight
// loop.
func (worker *Worker) Run(ctx context.Context) error {
	for {
		if err := worker.runNewSession(ctx); err != nil {
			return err
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(pollInterval):
			// start a new session
		}
	}
}
// runNewSession registers the worker with the controller, starts watching
// the controller's RPC stream in the background, and then, every
// pollInterval, refreshes the worker resource and syncs VMs.
//
// It returns:
//   - ErrRegistrationFailed if registration fails;
//   - the context's error once ctx is cancelled;
//   - nil when an update or sync iteration fails, leaving it to the caller
//     to decide when to start a new session.
func (worker *Worker) runNewSession(ctx context.Context) error {
	subCtx, cancel := context.WithCancel(ctx)
	// Cancelling subCtx on return ends the RPC watcher's retry loop below.
	defer cancel()

	if err := worker.registerWorker(subCtx); err != nil {
		worker.logger.Warnf("failed to register worker: %v", err)
		return ErrRegistrationFailed
	}

	go func() {
		// Attempts(0) makes retry-go keep retrying until watchRPC succeeds
		// or subCtx is done.
		_ = retry.Do(func() error {
			return worker.watchRPC(subCtx)
		}, retry.OnRetry(func(n uint, err error) {
			worker.logger.Warnf("failed to watch RPC: %v", err)
		}), retry.Context(subCtx), retry.Attempts(0))
	}()

	// Stop the ticker when the session ends; otherwise every session leaks one.
	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		if err := worker.updateWorker(subCtx); err != nil {
			worker.logger.Errorf("failed to update worker resource: %v", err)
			return nil
		}

		if err := worker.syncVMs(subCtx); err != nil {
			worker.logger.Warnf("failed to sync VMs: %v", err)
			return nil
		}

		select {
		case <-ticker.C:
			// continue
		case <-subCtx.Done():
			return subCtx.Err()
		}
	}
}
// registerWorker creates this worker's resource on the controller. The
// resource carries the worker's name, the current time as LastSeen, and the
// host's platform UUID as MachineID.
func (worker *Worker) registerWorker(ctx context.Context) error {
	machineID, err := iokitregistry.PlatformUUID()
	if err != nil {
		return err
	}

	resource := v1.Worker{
		Meta:      v1.Meta{Name: worker.name},
		LastSeen:  time.Now(),
		MachineID: machineID,
	}
	if _, err := worker.client.Workers().Create(ctx, resource); err != nil {
		return err
	}

	worker.logger.Infof("registered worker %s", worker.name)

	return nil
}
// updateWorker bumps this worker's LastSeen timestamp on the controller by
// reading the worker resource, setting LastSeen to now, and writing it back.
// Errors from either API call are wrapped with ErrPollFailed.
func (worker *Worker) updateWorker(ctx context.Context) error {
	current, err := worker.client.Workers().Get(ctx, worker.name)
	if err != nil {
		return fmt.Errorf("%w: failed to retrieve worker from the API: %v", ErrPollFailed, err)
	}
	worker.logger.Debugf("got worker from the API")

	current.LastSeen = time.Now()

	_, err = worker.client.Workers().Update(ctx, *current)
	if err != nil {
		return fmt.Errorf("%w: failed to update worker in the API: %v", ErrPollFailed, err)
	}
	worker.logger.Debugf("updated worker in the API")

	return nil
}
// syncVMs reconciles the VMs running locally with the VM resources the
// controller has assigned to this worker. It runs three passes:
//
//  1. stop local VMs whose remote status is Stopping;
//  2. create VMs that are Pending remotely and not yet present locally;
//  3. delete local VMs that no longer exist remotely, and report a local run
//     error to the controller for VMs not already marked Failed.
//
// The first error encountered aborts the sync and is returned.
func (worker *Worker) syncVMs(ctx context.Context) error {
	remoteVMs, err := worker.client.VMs().FindForWorker(ctx, worker.name)
	if err != nil {
		return err
	}

	worker.logger.Infof("syncing %d VMs...", len(remoteVMs))

	// Pass 1: stop VMs the controller asked to stop.
	for _, vmResource := range remoteVMs {
		if vmResource.Status == v1.VMStatusStopping && worker.vmm.Exists(vmResource) {
			if err := worker.stopVM(vmResource); err != nil {
				return err
			}
		}
	}

	// Pass 2: create pending VMs that aren't running here yet.
	for _, vmResource := range remoteVMs {
		if vmResource.Status == v1.VMStatusPending && !worker.vmm.Exists(vmResource) {
			if err := worker.createVM(ctx, vmResource); err != nil {
				return err
			}
		}
	}

	// Pass 3: reconcile local VMs against the remote view.
	for _, vm := range worker.vmm.List() {
		remoteVM, ok := remoteVMs[vm.Resource.UID]
		if !ok {
			// The controller no longer knows about this VM: tear it down.
			if err := worker.deleteVM(vm.Resource); err != nil {
				return err
			}
			continue
		}

		if remoteVM.Status != v1.VMStatusFailed && vm.RunError != nil {
			remoteVM.Status = v1.VMStatusFailed
			remoteVM.StatusMessage = fmt.Sprintf("failed to run VM: %v", vm.RunError)

			// Send remoteVM, the copy that carries the Failed status and
			// message. The local vm.Resource doesn't carry them.
			updatedVM, err := worker.client.VMs().Update(ctx, remoteVM)
			if err != nil {
				return err
			}
			vm.Resource = *updatedVM
		}
	}

	return nil
}
// deleteVM tears down a local VM. If the resource is not in a terminal state,
// the VM is first stopped via stopVM, which also reports the resulting status
// to the controller. The VM is then removed from the local VM manager if it
// is still present there.
func (worker *Worker) deleteVM(vmResource v1.VM) error {
	worker.logger.Debugf("deleting VM %s (%s)", vmResource.Name, vmResource.UID)

	needsStop := !vmResource.TerminalState()
	if needsStop {
		if stopErr := worker.stopVM(vmResource); stopErr != nil {
			return stopErr
		}
	}

	if present := worker.vmm.Exists(vmResource); present {
		if deleteErr := worker.vmm.Delete(vmResource); deleteErr != nil {
			return deleteErr
		}
	}

	worker.logger.Infof("deleted VM %s (%s)", vmResource.Name, vmResource.UID)

	return nil
}
// createVM starts a local VM for vmResource and reports the outcome to the
// controller:
//   - if creation fails, the status is Failed with the creation error;
//   - otherwise the status is Running.
//
// After a successful start, the VM's startup script (if any) runs in a
// background goroutine. A script failure is reported as Failed.
//
// Only the VM creation error is returned. Failures to update the controller
// are logged, not returned.
func (worker *Worker) createVM(ctx context.Context, vmResource v1.VM) error {
	worker.logger.Debugf("creating VM %s (%s)", vmResource.Name, vmResource.UID)

	// pushStatus sends resource to the controller and logs any failure.
	pushStatus := func(resource v1.VM) {
		if _, err := worker.client.VMs().Update(context.Background(), resource); err != nil {
			worker.logger.Errorf("failed to update VM %s (%s) remotely: %s", resource.Name, resource.UID, err.Error())
		}
	}

	vm, createErr := worker.vmm.Create(ctx, vmResource, worker.logger)
	if createErr != nil {
		vmResource.Status = v1.VMStatusFailed
		vmResource.StatusMessage = fmt.Sprintf("VM creation failed: %v", createErr)
		pushStatus(vmResource)

		return createErr
	}

	worker.logger.Infof("spawned VM %s (%s)", vmResource.Name, vmResource.UID)

	vmResource.Status = v1.VMStatusRunning
	pushStatus(vmResource)

	go func() {
		scriptErr := worker.execScript(vmResource, vm.Resource.StartupScript)
		if scriptErr == nil {
			return
		}

		vmResource.Status = v1.VMStatusFailed
		vmResource.StatusMessage = fmt.Sprintf("failed to run script: %v", scriptErr)
		pushStatus(vmResource)
	}()

	return nil
}
// execScript runs script inside the locally managed VM described by
// vmResource. Each output line is streamed to the controller as a log-line
// event. A nil script is a no-op.
//
// It returns an error if the VM cannot be found locally or if running the
// script fails. Script failures are also logged here.
func (worker *Worker) execScript(vmResource v1.VM, script *v1.VMScript) error {
	if script == nil {
		return nil
	}

	vm, err := worker.vmm.Get(vmResource)
	if err != nil {
		// Surface the lookup failure rather than reporting the script as
		// having run successfully.
		return fmt.Errorf("failed to find VM %s (%s) locally: %w", vmResource.Name, vmResource.UID, err)
	}

	eventsStreamer := worker.client.VMs().StreamEvents(vmResource.Name)
	defer func() {
		if closeErr := eventsStreamer.Close(); closeErr != nil {
			// %v, not %w: the sugared logger formats templates like
			// fmt.Sprintf, which does not support %w.
			worker.logger.Errorf("errored during streaming events for %s (%s): %v", vmResource.Name, vmResource.UID, closeErr)
		}
	}()

	err = vm.Shell(context.Background(), vmResource.Username, vmResource.Password,
		script.ScriptContent, script.Env,
		func(line string) {
			eventsStreamer.Stream(v1.Event{
				Kind:      v1.EventKindLogLine,
				Timestamp: time.Now().Unix(),
				Payload:   line,
			})
		})
	if err != nil {
		worker.logger.Errorf("failed to run script for VM %s (%s): %s", vmResource.Name, vmResource.UID, err.Error())
	}

	return err
}
// stopVM shuts down a local VM and reports the result to the controller. It
// is a no-op when the VM does not exist locally.
//
// It first runs the VM's shutdown script (if any), then stops the VM. The
// status sent to the controller is:
//   - Stopped if both steps succeed;
//   - otherwise Failed, with a message from the shutdown-script error if
//     there was one, or else from the stop error.
//
// Only the stop error is returned. A failed remote update is logged.
func (worker *Worker) stopVM(vmResource v1.VM) error {
	worker.logger.Debugf("stopping VM %s (%s)", vmResource.Name, vmResource.UID)

	if !worker.vmm.Exists(vmResource) {
		return nil
	}

	shutdownScriptErr := worker.execScript(vmResource, vmResource.ShutdownScript)
	stopErr := worker.vmm.Stop(vmResource)

	switch {
	case shutdownScriptErr != nil:
		vmResource.Status = v1.VMStatusFailed
		vmResource.StatusMessage = fmt.Sprintf("failed to run shutdown script: %v", shutdownScriptErr)
	case stopErr != nil:
		vmResource.Status = v1.VMStatusFailed
		vmResource.StatusMessage = fmt.Sprintf("failed to stop vm: %v", stopErr)
	default:
		vmResource.Status = v1.VMStatusStopped
	}

	if _, updateErr := worker.client.VMs().Update(context.Background(), vmResource); updateErr != nil {
		worker.logger.Errorf("failed to update VM %s (%s) remotely: %s", vmResource.Name, vmResource.UID, updateErr.Error())
	}

	return stopErr
}
// DeleteAllVMs stops every locally managed VM and then deletes each of them.
// It keeps going past individual failures and aggregates all errors with
// multierror. It returns nil only when every stop and delete succeeded.
func (worker *Worker) DeleteAllVMs() error {
	var result error

	collect := func(err error) {
		if err != nil {
			result = multierror.Append(result, err)
		}
	}

	for _, vm := range worker.vmm.List() {
		collect(vm.Stop())
	}

	for _, vm := range worker.vmm.List() {
		collect(vm.Delete())
	}

	return result
}
// GPRCMetadata returns the client's gRPC metadata joined with this worker's
// name under rpc.MetadataWorkerNameKey. The method name's spelling matches
// client.GPRCMetadata.
func (worker *Worker) GPRCMetadata() metadata.MD {
	base := worker.client.GPRCMetadata()
	own := metadata.Pairs(rpc.MetadataWorkerNameKey, worker.name)

	return metadata.Join(base, own)
}