From 60e564da88e5ef30e70601ae47b4bc5a49317e58 Mon Sep 17 00:00:00 2001
From: Nikolay Edigaryev
Date: Mon, 24 Apr 2023 19:30:08 +0400
Subject: [PATCH] Implement restart policy for VMs (#83)

* Implement restart policy for VMs

* Do not update VM.Resource, we only use it as a read-only specification

* Err()/setErr(): use atomic.Pointer instead of sync.Mutex
---
 go.mod | 1 +
 go.sum | 2 +
 internal/command/create/vm.go | 69 ++++--
 internal/command/list/vms.go | 6 +-
 internal/command/pause/worker.go | 3 +-
 internal/command/root.go | 2 -
 internal/command/stop/stop.go | 16 --
 internal/command/stop/vm.go | 27 --
 internal/controller/api_vms.go | 12 +
 internal/controller/scheduler/scheduler.go | 25 +-
 internal/tests/integration_test.go | 72 +++---
 internal/worker/ondiskname/ondiskname.go | 53 ++--
 internal/worker/ondiskname/ondiskname_test.go | 13 +-
 internal/worker/rpc.go | 16 +-
 internal/worker/vmmanager/vm.go | 132 ++++++++--
 internal/worker/vmmanager/vmmanager.go | 64 ++---
 internal/worker/worker.go | 231 ++++++------------
 pkg/client/client.go | 26 +-
 pkg/client/vms.go | 22 +-
 pkg/resource/v1/restart_policy.go | 26 ++
 pkg/resource/v1/restart_policy_test.go | 26 ++
 pkg/resource/v1/v1.go | 20 +-
 22 files changed, 482 insertions(+), 382 deletions(-)
 delete mode 100644 internal/command/stop/stop.go
 delete mode 100644 internal/command/stop/vm.go
 create mode 100644 pkg/resource/v1/restart_policy.go
 create mode 100644 pkg/resource/v1/restart_policy_test.go

diff --git a/go.mod b/go.mod
index 2db2f00..618c232 100644
--- a/go.mod
+++ b/go.mod
@@ -20,6 +20,7 @@ require (
 	github.com/mitchellh/go-grpc-net-conn v0.0.0-20200427190222-eb030e4876f0
 	github.com/penglongli/gin-metrics v0.1.10
 	github.com/prometheus/client_golang v1.15.0
+	github.com/samber/lo v1.38.1
 	github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
 	github.com/spf13/cobra v1.6.0
 	github.com/stretchr/testify v1.8.1
diff --git a/go.sum b/go.sum
index d34b6c0..230960e 100644
--- a/go.sum
+++ b/go.sum
@@ -396,6 +396,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
 github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
 github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
+github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
 github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
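[Editorial aside, not part of the diff: the user-visible core of this patch is the new RestartPolicy type added in pkg/resource/v1/restart_policy.go further down — a VM is either never restarted ("Never", the default) or re-scheduled by the controller when it fails ("OnFailure"). A minimal sketch of how the new parser behaves, assuming this patch is applied; parsing is strict and case-sensitive, which is what the controller-side validation in internal/controller/api_vms.go below relies on:]

package main

import (
	"fmt"

	v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
)

func main() {
	policy, err := v1.NewRestartPolicyFromString("OnFailure")
	fmt.Println(policy, err) // OnFailure <nil>

	// Unknown or differently-capitalized values are rejected.
	_, err = v1.NewRestartPolicyFromString("always")
	fmt.Println(err) // invalid restart policy "always"
}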
diff --git a/internal/command/create/vm.go b/internal/command/create/vm.go
index 5e942ad..76da537 100644
--- a/internal/command/create/vm.go
+++ b/internal/command/create/vm.go
@@ -6,6 +6,8 @@ import (
 	"github.com/cirruslabs/orchard/pkg/client"
 	v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
 	"github.com/spf13/cobra"
+	"os"
+	"strings"
 )

 var ErrVMFailed = errors.New("failed to create VM")
@@ -16,7 +18,9 @@ var memory uint64
 var netSoftnet bool
 var netBridged string
 var headless bool
-var stringToStringResources map[string]string
+var resources map[string]string
+var restartPolicy string
+var startupScript string

 func newCreateVMCommand() *cobra.Command {
 	command := &cobra.Command{
@@ -32,8 +36,14 @@ func newCreateVMCommand() *cobra.Command {
 	command.PersistentFlags().BoolVar(&netSoftnet, "net-softnet", false, "whether to use Softnet network isolation")
 	command.PersistentFlags().StringVar(&netBridged, "net-bridged", "", "whether to use Bridged network mode")
 	command.PersistentFlags().BoolVar(&headless, "headless", true, "whether to run without graphics")
-	command.PersistentFlags().StringToStringVar(&stringToStringResources, "resources", map[string]string{},
+	command.PersistentFlags().StringToStringVar(&resources, "resources", map[string]string{},
 		"resources to request for this VM")
+	command.PersistentFlags().StringVar(&restartPolicy, "restart-policy", "Never",
+		"restart policy for this VM: specify \"Never\" to never restart or \"OnFailure\" "+
+			"to only restart when the VM fails")
+	command.PersistentFlags().StringVar(&startupScript, "startup-script", "",
+		"startup script (e.g. --startup-script=\"sync\") or a path to a script file prefixed with \"@\" "+
+			"(e.g. \"--startup-script=@script.sh\")")

 	return command
 }
@@ -41,18 +51,7 @@ func newCreateVMCommand() *cobra.Command {
 func runCreateVM(cmd *cobra.Command, args []string) error {
 	name := args[0]

-	// Convert resources
-	resources, err := v1.NewResourcesFromStringToString(stringToStringResources)
-	if err != nil {
-		return fmt.Errorf("%w: %v", ErrVMFailed, err)
-	}
-
-	client, err := client.New()
-	if err != nil {
-		return err
-	}
-
-	return client.VMs().Create(cmd.Context(), &v1.VM{
+	vm := &v1.VM{
 		Meta: v1.Meta{
 			Name: name,
 		},
@@ -62,6 +61,44 @@ func runCreateVM(cmd *cobra.Command, args []string) error {
 		NetSoftnet: netSoftnet,
 		NetBridged: netBridged,
 		Headless:   headless,
-		Resources:  resources,
-	})
+	}
+
+	// Convert resources
+	var err error
+
+	vm.Resources, err = v1.NewResourcesFromStringToString(resources)
+	if err != nil {
+		return fmt.Errorf("%w: %v", ErrVMFailed, err)
+	}
+
+	// Convert restart policy
+	vm.RestartPolicy, err = v1.NewRestartPolicyFromString(restartPolicy)
+	if err != nil {
+		return fmt.Errorf("%w: %v", ErrVMFailed, err)
+	}
+
+	// Convert startup script, optionally reading it from the file system
+	const scriptFilePrefix = "@"
+
+	if strings.HasPrefix(startupScript, scriptFilePrefix) {
+		startupScriptBytes, err := os.ReadFile(strings.TrimPrefix(startupScript, scriptFilePrefix))
+		if err != nil {
+			return err
+		}
+
+		vm.StartupScript = &v1.VMScript{
+			ScriptContent: string(startupScriptBytes),
+		}
+	} else if startupScript != "" {
+		vm.StartupScript = &v1.VMScript{
+			ScriptContent: startupScript,
+		}
+	}
+
+	client, err := client.New()
+	if err != nil {
+		return err
+	}
+
+	return client.VMs().Create(cmd.Context(), vm)
 }
diff --git a/internal/command/list/vms.go b/internal/command/list/vms.go
index 4cd1364..5fc0576 100644
--- a/internal/command/list/vms.go
+++ b/internal/command/list/vms.go
@@ -38,10 +38,12 @@ func runListVMs(cmd *cobra.Command, args []string) error {

 	table := uitable.New()

-	table.AddRow("Name", "Image", "Status")
+	table.AddRow("Name", "Image", "Status", "Restart policy")

 	for _, vm := range vms {
-		table.AddRow(vm.Name, vm.Image, vm.Status)
+		restartPolicyInfo := fmt.Sprintf("%s (%d restarts)", vm.RestartPolicy, vm.RestartCount)
+
+		table.AddRow(vm.Name, vm.Image, vm.Status, restartPolicyInfo)
 	}

 	fmt.Println(table)
"github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/spf13/cobra" - "golang.org/x/exp/maps" "golang.org/x/exp/slices" "time" ) @@ -58,7 +57,7 @@ func runPauseWorker(cmd *cobra.Command, args []string) error { return err } - hasActiveVMs := slices.ContainsFunc(maps.Values(vms), func(vm v1.VM) bool { + hasActiveVMs := slices.ContainsFunc(vms, func(vm v1.VM) bool { return !vm.TerminalState() }) if !hasActiveVMs { diff --git a/internal/command/root.go b/internal/command/root.go index 131d89d..d423a43 100644 --- a/internal/command/root.go +++ b/internal/command/root.go @@ -13,7 +13,6 @@ import ( "github.com/cirruslabs/orchard/internal/command/portforward" "github.com/cirruslabs/orchard/internal/command/resume" "github.com/cirruslabs/orchard/internal/command/ssh" - "github.com/cirruslabs/orchard/internal/command/stop" "github.com/cirruslabs/orchard/internal/command/vnc" "github.com/cirruslabs/orchard/internal/command/worker" "github.com/cirruslabs/orchard/internal/version" @@ -39,7 +38,6 @@ func NewRootCmd() *cobra.Command { ssh.NewCommand(), vnc.NewCommand(), deletepkg.NewCommand(), - stop.NewCommand(), ) addGroupedCommands(command, "Administrative Tasks:", diff --git a/internal/command/stop/stop.go b/internal/command/stop/stop.go deleted file mode 100644 index fb3a6ea..0000000 --- a/internal/command/stop/stop.go +++ /dev/null @@ -1,16 +0,0 @@ -package stop - -import ( - "github.com/spf13/cobra" -) - -func NewCommand() *cobra.Command { - command := &cobra.Command{ - Use: "stop", - Short: "Stop resources", - } - - command.AddCommand(newStopVMCommand()) - - return command -} diff --git a/internal/command/stop/vm.go b/internal/command/stop/vm.go deleted file mode 100644 index cf67fc7..0000000 --- a/internal/command/stop/vm.go +++ /dev/null @@ -1,27 +0,0 @@ -package stop - -import ( - "github.com/cirruslabs/orchard/pkg/client" - "github.com/spf13/cobra" -) - -func newStopVMCommand() *cobra.Command { - return &cobra.Command{ - Use: "vm NAME", - Short: "Stop a VM", - Args: cobra.ExactArgs(1), - RunE: runStopVM, - } -} - -func runStopVM(cmd *cobra.Command, args []string) error { - name := args[0] - - client, err := client.New() - if err != nil { - return err - } - - _, err = client.VMs().Stop(cmd.Context(), name) - return err -} diff --git a/internal/controller/api_vms.go b/internal/controller/api_vms.go index 3c43ab7..b586998 100644 --- a/internal/controller/api_vms.go +++ b/internal/controller/api_vms.go @@ -37,6 +37,8 @@ func (controller *Controller) createVM(ctx *gin.Context) responder.Responder { vm.Status = v1.VMStatusPending vm.CreatedAt = time.Now() + vm.RestartedAt = time.Time{} + vm.RestartCount = 0 vm.UID = uuid.New().String() // Provide resource defaults @@ -47,6 +49,16 @@ func (controller *Controller) createVM(ctx *gin.Context) responder.Responder { vm.Resources[v1.ResourceTartVMs] = 1 } + // Validate restart policy and provide a default value if it's missing + if vm.RestartPolicy != "" { + if _, err := v1.NewRestartPolicyFromString(string(vm.RestartPolicy)); err != nil { + return responder.JSON(http.StatusPreconditionFailed, + NewErrorResponse("unsupported restart policy: %q", vm.RestartPolicy)) + } + } else { + vm.RestartPolicy = v1.RestartPolicyNever + } + response := controller.storeUpdate(func(txn storepkg.Transaction) responder.Responder { // Does the VM resource with this name already exists? 
diff --git a/internal/controller/scheduler/scheduler.go b/internal/controller/scheduler/scheduler.go
index 72e94b9..dbcf7a8 100644
--- a/internal/controller/scheduler/scheduler.go
+++ b/internal/controller/scheduler/scheduler.go
@@ -13,7 +13,11 @@ import (
 	"time"
 )

-const schedulerInterval = 5 * time.Second
+const (
+	schedulerInterval = 5 * time.Second
+
+	schedulerVMRestartDelay = 15 * time.Second
+)

 var (
 	schedulerLoopIterationStat = promauto.NewCounter(prometheus.CounterOpts{
@@ -205,6 +209,25 @@ func (scheduler *Scheduler) healthCheckingLoopIteration() error {
 }

 func (scheduler *Scheduler) healthCheckVM(txn storepkg.Transaction, nameToWorker map[string]v1.Worker, vm v1.VM) error {
+	logger := scheduler.logger.With("vm_name", vm.Name, "vm_uid", vm.UID, "vm_restart_count", vm.RestartCount)
+
+	// Schedule a VM restart if the restart policy mandates it
+	needsRestart := vm.RestartPolicy == v1.RestartPolicyOnFailure &&
+		vm.Status == v1.VMStatusFailed &&
+		time.Since(vm.RestartedAt) > schedulerVMRestartDelay
+
+	if needsRestart {
+		logger.Debugf("restarting VM")
+
+		vm.Status = v1.VMStatusPending
+		vm.StatusMessage = ""
+		vm.Worker = ""
+		vm.RestartedAt = time.Now()
+		vm.RestartCount++
+
+		return txn.SetVM(vm)
+	}
+
 	worker, ok := nameToWorker[vm.Worker]
 	if !ok {
 		vm.Status = v1.VMStatusFailed
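[Editorial aside, not part of the diff: the restart predicate above only re-schedules a failed VM once schedulerVMRestartDelay has elapsed since RestartedAt. Because createVM in internal/controller/api_vms.go above resets RestartedAt to the zero time, time.Since() yields a duration measured from year 1, so the very first restart is never held back; only subsequent restarts are spaced out. A self-contained illustration:]

package main

import (
	"fmt"
	"time"
)

func main() {
	const schedulerVMRestartDelay = 15 * time.Second // same value as in the diff

	var restartedAt time.Time // zero value, as set by the controller on VM creation
	fmt.Println(time.Since(restartedAt) > schedulerVMRestartDelay) // true: restart immediately

	restartedAt = time.Now()
	fmt.Println(time.Since(restartedAt) > schedulerVMRestartDelay) // false: wait out the delay
}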
Current status: %s", vm.Status) - return vm.Status == v1.VMStatusStopped - }), "failed to stop a VM") - logLines, err = devClient.VMs().Logs(context.Background(), "test-vm") - if err != nil { - t.Fatal(err) - } - assert.Equal(t, []string{"Hello, Bar!", "Buy!"}, logLines) + t.Logf("Waiting for the VM to be garbage collected...") + + return !hasVMByPredicate(t, func(info tart.VMInfo) bool { + return strings.Contains(info.Name, runningVM.UID) + }, nil) + }), "VM was not garbage collected in a timely manner") } func TestFailedStartupScript(t *testing.T) { @@ -133,7 +129,8 @@ func TestFailedStartupScript(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, "failed to run script: Process exited with status 123", runningVM.StatusMessage) + assert.Contains(t, runningVM.StatusMessage, + "failed to run startup script: Process exited with status 123") } func Wait(duration time.Duration, condition func() bool) bool { @@ -172,7 +169,7 @@ func StartIntegrationTestEnvironmentWithAdditionalOpts( t.Fatal(err) } t.Cleanup(func() { - _ = devWorker.DeleteAllVMs() + _ = devWorker.Close() }) devContext, cancelDevFunc := context.WithCancel(context.Background()) t.Cleanup(cancelDevFunc) @@ -401,21 +398,13 @@ func TestVMGarbageCollection(t *testing.T) { require.NoError(t, err) // Create on-disk Tart VM that looks like it's managed by Orchard - vmName := ondiskname.New("test", uuid.New().String()).String() + vmName := ondiskname.New("test", uuid.New().String(), 0).String() _, _, err = tart.Tart(ctx, logger.Sugar(), "clone", "ghcr.io/cirruslabs/macos-ventura-base:latest", vmName) require.NoError(t, err) // Make sure that this VM exists - hasVM := func(name string) bool { - vmInfos, err := tart.List(ctx, logger.Sugar()) - require.NoError(t, err) - - return slices.ContainsFunc(vmInfos, func(vmInfo tart.VMInfo) bool { - return vmInfo.Name == name - }) - } - require.True(t, hasVM(vmName)) + require.True(t, hasVM(t, vmName, logger)) // Start the Orchard Worker _ = StartIntegrationTestEnvironment(t) @@ -424,6 +413,23 @@ func TestVMGarbageCollection(t *testing.T) { require.True(t, Wait(2*time.Minute, func() bool { t.Logf("Waiting for the on-disk VM to be cleaned up by the worker") - return !hasVM(vmName) + return !hasVM(t, vmName, logger) }), "failed to wait for the VM %s to be garbage-collected", vmName) } + +func hasVM(t *testing.T, name string, logger *zap.Logger) bool { + return hasVMByPredicate(t, func(vmInfo tart.VMInfo) bool { + return vmInfo.Name == name + }, logger) +} + +func hasVMByPredicate(t *testing.T, predicate func(tart.VMInfo) bool, logger *zap.Logger) bool { + if logger == nil { + logger = zap.Must(zap.NewDevelopment()) + } + + vmInfos, err := tart.List(context.Background(), logger.Sugar()) + require.NoError(t, err) + + return slices.ContainsFunc(vmInfos, predicate) +} diff --git a/internal/worker/ondiskname/ondiskname.go b/internal/worker/ondiskname/ondiskname.go index c7b602a..a458d5b 100644 --- a/internal/worker/ondiskname/ondiskname.go +++ b/internal/worker/ondiskname/ondiskname.go @@ -3,6 +3,8 @@ package ondiskname import ( "errors" "fmt" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "strconv" "strings" ) @@ -12,19 +14,34 @@ var ( ) const ( - prefix = "orchard" - numHyphensInUUID = 5 + prefix = "orchard" + + numPartsPrefix = 1 + numPartsName = 1 + numPartsUUID = 5 + numPartsRestartCount = 1 + numPartsTotal = numPartsPrefix + numPartsName + numPartsUUID + numPartsRestartCount ) type OnDiskName struct { - Name string - UID string + Name string + UID string + RestartCount uint64 } 
diff --git a/internal/worker/ondiskname/ondiskname.go b/internal/worker/ondiskname/ondiskname.go
index c7b602a..a458d5b 100644
--- a/internal/worker/ondiskname/ondiskname.go
+++ b/internal/worker/ondiskname/ondiskname.go
@@ -3,6 +3,8 @@ package ondiskname
 import (
 	"errors"
 	"fmt"
+	v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
+	"strconv"
 	"strings"
 )
@@ -12,19 +14,34 @@ var (
 )

 const (
-	prefix           = "orchard"
-	numHyphensInUUID = 5
+	prefix = "orchard"
+
+	numPartsPrefix       = 1
+	numPartsName         = 1
+	numPartsUUID         = 5
+	numPartsRestartCount = 1
+	numPartsTotal        = numPartsPrefix + numPartsName + numPartsUUID + numPartsRestartCount
 )

 type OnDiskName struct {
-	Name string
-	UID  string
+	Name         string
+	UID          string
+	RestartCount uint64
 }

-func New(name string, uid string) OnDiskName {
+func New(name string, uid string, restartCount uint64) OnDiskName {
 	return OnDiskName{
-		Name: name,
-		UID:  uid,
+		Name:         name,
+		UID:          uid,
+		RestartCount: restartCount,
+	}
+}
+
+func NewFromResource(vm v1.VM) OnDiskName {
+	return OnDiskName{
+		Name:         vm.Name,
+		UID:          vm.UID,
+		RestartCount: vm.RestartCount,
 	}
 }

@@ -35,9 +52,9 @@ func Parse(s string) (OnDiskName, error) {
 		return OnDiskName{}, ErrNotManagedByOrchard
 	}

-	if len(splits) < 7 {
-		return OnDiskName{}, fmt.Errorf("%w: name should contain at least 7 parts delimited by \"-\"",
-			ErrInvalidOnDiskName)
+	if len(splits) < numPartsTotal {
+		return OnDiskName{}, fmt.Errorf("%w: name should contain at least %d parts delimited by \"-\"",
+			ErrInvalidOnDiskName, numPartsTotal)
 	}

 	if splits[0] != prefix {
@@ -45,14 +62,22 @@ func Parse(s string) (OnDiskName, error) {
 			ErrInvalidOnDiskName, prefix)
 	}

-	uuidStart := len(splits) - numHyphensInUUID
+	uuidStart := len(splits) - numPartsUUID - numPartsRestartCount
+
+	restartCountRaw := splits[uuidStart+numPartsUUID]
+	restartCount, err := strconv.ParseUint(restartCountRaw, 10, 64)
+	if err != nil {
+		return OnDiskName{}, fmt.Errorf("%w: invalid restart count %q",
+			ErrInvalidOnDiskName, restartCountRaw)
+	}

 	return OnDiskName{
-		Name: strings.Join(splits[1:uuidStart], "-"),
-		UID:  strings.Join(splits[uuidStart:], "-"),
+		Name:         strings.Join(splits[1:uuidStart], "-"),
+		UID:          strings.Join(splits[uuidStart:uuidStart+numPartsUUID], "-"),
+		RestartCount: restartCount,
 	}, nil
 }

 func (odn OnDiskName) String() string {
-	return fmt.Sprintf("%s-%s-%s", prefix, odn.Name, odn.UID)
+	return fmt.Sprintf("%s-%s-%s-%d", prefix, odn.Name, odn.UID, odn.RestartCount)
 }
diff --git a/internal/worker/ondiskname/ondiskname_test.go b/internal/worker/ondiskname/ondiskname_test.go
index 1e2f5ba..4ea2439 100644
--- a/internal/worker/ondiskname/ondiskname_test.go
+++ b/internal/worker/ondiskname/ondiskname_test.go
@@ -1,14 +1,23 @@
 package ondiskname_test

 import (
+	"fmt"
 	"github.com/cirruslabs/orchard/internal/worker/ondiskname"
 	"github.com/google/uuid"
 	"github.com/stretchr/testify/require"
 	"testing"
 )

+func TestOnDiskNameFromStaticString(t *testing.T) {
+	uuid := uuid.New().String()
+
+	parsedOnDiskName, err := ondiskname.Parse(fmt.Sprintf("orchard-vm-name-%s-42", uuid))
+	require.NoError(t, err)
+	require.Equal(t, ondiskname.OnDiskName{"vm-name", uuid, 42}, parsedOnDiskName)
+}
+
 func TestOnDiskNameUUID(t *testing.T) {
-	onDiskNameOriginal := ondiskname.New("test-vm--", uuid.New().String())
+	onDiskNameOriginal := ondiskname.New("test-vm--", uuid.New().String(), 0)

 	onDiskNameParsed, err := ondiskname.Parse(onDiskNameOriginal.String())
 	require.NoError(t, err)
@@ -17,7 +26,7 @@ func TestOnDiskNameUUID(t *testing.T) {
 }

 func TestOnDiskNameNonUUID(t *testing.T) {
-	onDiskNameOriginal := ondiskname.New("some-vm", "some-uid")
+	onDiskNameOriginal := ondiskname.New("some-vm", "some-uid", 0)

 	_, err := ondiskname.Parse(onDiskNameOriginal.String())
 	require.Error(t, err)
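[Editorial aside, not part of the diff: the on-disk name now carries the restart count as a trailing numeric segment, so each restarted incarnation of a VM gets a distinct Tart VM name on the worker. A round-trip sketch with a hypothetical name and UID; this only compiles from inside the repository, since the package sits under internal/:]

package main

import (
	"fmt"

	"github.com/cirruslabs/orchard/internal/worker/ondiskname"
)

func main() {
	odn := ondiskname.New("my-vm", "8b218da2-2b6e-4a2b-9c4f-3f1d0a6f8c2e", 2)

	s := odn.String()
	fmt.Println(s) // orchard-my-vm-8b218da2-2b6e-4a2b-9c4f-3f1d0a6f8c2e-2

	parsed, err := ondiskname.Parse(s)
	if err != nil {
		panic(err)
	}
	fmt.Println(parsed.Name, parsed.RestartCount) // my-vm 2
}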
"google.golang.org/grpc/metadata" "net" + + "github.com/samber/lo" ) func (worker *Worker) watchRPC(ctx context.Context) error { @@ -31,7 +33,7 @@ func (worker *Worker) watchRPC(ctx context.Context) error { client := rpc.NewControllerClient(conn) - ctxWithMetadata := metadata.NewOutgoingContext(ctx, worker.GPRCMetadata()) + ctxWithMetadata := metadata.NewOutgoingContext(ctx, worker.grpcMetadata()) stream, err := client.Watch(ctxWithMetadata, &emptypb.Empty{}) if err != nil { @@ -48,7 +50,7 @@ func (worker *Worker) watchRPC(ctx context.Context) error { case *rpc.WatchInstruction_PortForwardAction: go worker.handlePortForward(ctxWithMetadata, client, action.PortForwardAction) case *rpc.WatchInstruction_SyncVmsAction: - worker.RequestVMSyncing() + worker.requestVMSyncing() } } } @@ -62,7 +64,7 @@ func (worker *Worker) handlePortForward( defer cancel() grpcMetadata := metadata.Join( - worker.GPRCMetadata(), + worker.grpcMetadata(), metadata.Pairs(rpc.MetadataWorkerPortForwardingSessionKey, portForwardAction.Session), ) ctxWithMetadata := metadata.NewOutgoingContext(subCtx, grpcMetadata) @@ -74,10 +76,10 @@ func (worker *Worker) handlePortForward( } // Obtain VM - vm, err := worker.vmm.Get(v1.VM{ - UID: portForwardAction.VmUid, + vm, ok := lo.Find(worker.vmm.List(), func(item *vmmanager.VM) bool { + return item.Resource.UID == portForwardAction.VmUid }) - if err != nil { + if !ok { worker.logger.Warnf("port forwarding failed: failed to get the VM: %v", err) return diff --git a/internal/worker/vmmanager/vm.go b/internal/worker/vmmanager/vm.go index 6aa48d6..f8d7ccc 100644 --- a/internal/worker/vmmanager/vm.go +++ b/internal/worker/vmmanager/vm.go @@ -8,6 +8,7 @@ import ( "github.com/avast/retry-go" "github.com/cirruslabs/orchard/internal/worker/ondiskname" "github.com/cirruslabs/orchard/internal/worker/tart" + "github.com/cirruslabs/orchard/pkg/client" "github.com/cirruslabs/orchard/pkg/resource/v1" "go.uber.org/zap" "golang.org/x/crypto/ssh" @@ -16,15 +17,18 @@ import ( "strconv" "strings" "sync" + "sync/atomic" + "time" ) -var ErrVMFailed = errors.New("VM errored") +var ErrVMFailed = errors.New("VM failed") type VM struct { - id string - Resource v1.VM - logger *zap.SugaredLogger - RunError error + onDiskName ondiskname.OnDiskName + Resource v1.VM + logger *zap.SugaredLogger + + err atomic.Pointer[error] ctx context.Context cancel context.CancelFunc @@ -32,13 +36,22 @@ type VM struct { wg *sync.WaitGroup } -func NewVM(ctx context.Context, vmResource v1.VM, logger *zap.SugaredLogger) (*VM, error) { +func NewVM( + ctx context.Context, + vmResource v1.VM, + eventStreamer *client.EventStreamer, + logger *zap.SugaredLogger, +) (*VM, error) { vmContext, vmContextCancel := context.WithCancel(context.Background()) vm := &VM{ - id: ondiskname.New(vmResource.Name, vmResource.UID).String(), - Resource: vmResource, - logger: logger, + onDiskName: ondiskname.NewFromResource(vmResource), + Resource: vmResource, + logger: logger.With( + "vm_uid", vmResource.UID, + "vm_name", vmResource.Name, + "vm_restart_count", vmResource.RestartCount, + ), ctx: vmContext, cancel: vmContextCancel, @@ -47,6 +60,8 @@ func NewVM(ctx context.Context, vmResource v1.VM, logger *zap.SugaredLogger) (*V } // Clone the VM so `run` and `ip` are not racing + vm.logger.Debugf("creating VM") + if err := vm.cloneAndConfigure(ctx); err != nil { return nil, fmt.Errorf("failed to clone the VM: %w", err) } @@ -56,24 +71,53 @@ func NewVM(ctx context.Context, vmResource v1.VM, logger *zap.SugaredLogger) (*V go func() { defer vm.wg.Done() + 
vm.logger.Debugf("spawned VM") + if err := vm.run(vm.ctx); err != nil { - logger.Errorf("VM %s failed: %v", vm.id, err) - vm.RunError = err + vm.setErr(fmt.Errorf("%w: %v", ErrVMFailed, err)) } + + vm.setErr(fmt.Errorf("%w: VM exited unexpectedly", ErrVMFailed)) }() + if vm.Resource.StartupScript != nil { + go vm.runScript(vm.Resource.StartupScript, eventStreamer) + } + return vm, nil } +func (vm *VM) OnDiskName() ondiskname.OnDiskName { + return vm.onDiskName +} + +func (vm *VM) id() string { + return vm.onDiskName.String() +} + +func (vm *VM) Err() error { + if err := vm.err.Load(); err != nil { + return *err + } + + return nil +} + +func (vm *VM) setErr(err error) { + if vm.err.CompareAndSwap(nil, &err) { + vm.logger.Error(err) + } +} + func (vm *VM) cloneAndConfigure(ctx context.Context) error { - _, _, err := tart.Tart(ctx, vm.logger, "clone", vm.Resource.Image, vm.id) + _, _, err := tart.Tart(ctx, vm.logger, "clone", vm.Resource.Image, vm.id()) if err != nil { return err } if vm.Resource.Memory != 0 { _, _, err = tart.Tart(ctx, vm.logger, "set", "--memory", - strconv.FormatUint(vm.Resource.Memory, 10), vm.id) + strconv.FormatUint(vm.Resource.Memory, 10), vm.id()) if err != nil { return err } @@ -81,7 +125,7 @@ func (vm *VM) cloneAndConfigure(ctx context.Context) error { if vm.Resource.CPU != 0 { _, _, err = tart.Tart(ctx, vm.logger, "set", "--cpu", - strconv.FormatUint(vm.Resource.CPU, 10), vm.id) + strconv.FormatUint(vm.Resource.CPU, 10), vm.id()) if err != nil { return err } @@ -103,7 +147,7 @@ func (vm *VM) run(ctx context.Context) error { runArgs = append(runArgs, "--no-graphics") } - runArgs = append(runArgs, vm.id) + runArgs = append(runArgs, vm.id()) _, _, err := tart.Tart(ctx, vm.logger, runArgs...) if err != nil { return err @@ -113,7 +157,7 @@ func (vm *VM) run(ctx context.Context) error { } func (vm *VM) IP(ctx context.Context) (string, error) { - stdout, _, err := tart.Tart(ctx, vm.logger, "ip", "--wait", "60", vm.id) + stdout, _, err := tart.Tart(ctx, vm.logger, "ip", "--wait", "60", vm.id()) if err != nil { return "", err } @@ -122,7 +166,11 @@ func (vm *VM) IP(ctx context.Context) (string, error) { } func (vm *VM) Stop() error { - _, _, _ = tart.Tart(context.Background(), vm.logger, "stop", vm.id) + vm.logger.Debugf("stopping VM") + + _, _, _ = tart.Tart(context.Background(), vm.logger, "stop", vm.id()) + + vm.logger.Debugf("VM stopped") vm.cancel() @@ -132,15 +180,19 @@ func (vm *VM) Stop() error { } func (vm *VM) Delete() error { - _, _, err := tart.Tart(context.Background(), vm.logger, "delete", vm.id) + vm.logger.Debugf("deleting VM") + + _, _, err := tart.Tart(context.Background(), vm.logger, "delete", vm.id()) if err != nil { - return fmt.Errorf("%w: failed to delete VM %s: %v", ErrFailed, vm.id, err) + return fmt.Errorf("%w: failed to delete VM: %v", ErrVMFailed, err) } + vm.logger.Debugf("deleted VM") + return nil } -func (vm *VM) Shell( +func (vm *VM) shell( ctx context.Context, sshUser string, sshPassword string, @@ -191,17 +243,17 @@ func (vm *VM) Shell( sess, err := cli.NewSession() if err != nil { - return fmt.Errorf("%w: failed to open SSH session: %v", ErrFailed, err) + return fmt.Errorf("%w: failed to open SSH session: %v", ErrVMFailed, err) } // Log output from the virtual machine stdout, err := sess.StdoutPipe() if err != nil { - return fmt.Errorf("%w: while opening stdout pipe: %v", ErrFailed, err) + return fmt.Errorf("%w: while opening stdout pipe: %v", ErrVMFailed, err) } stderr, err := sess.StderrPipe() if err != nil { - return fmt.Errorf("%w: 
-		return fmt.Errorf("%w: while opening stderr pipe: %v", ErrFailed, err)
+		return fmt.Errorf("%w: while opening stderr pipe: %v", ErrVMFailed, err)
 	}
 	var outputReaderWG sync.WaitGroup
 	outputReaderWG.Add(1)
@@ -218,13 +270,13 @@ func (vm *VM) shell(

 	stdinBuf, err := sess.StdinPipe()
 	if err != nil {
-		return fmt.Errorf("%w: while opening stdin pipe: %v", ErrFailed, err)
+		return fmt.Errorf("%w: while opening stdin pipe: %v", ErrVMFailed, err)
 	}

 	// start a login shell so all the customization from ~/.zprofile will be picked up
 	err = sess.Shell()
 	if err != nil {
-		return fmt.Errorf("%w: failed to start a shell: %v", ErrFailed, err)
+		return fmt.Errorf("%w: failed to start a shell: %v", ErrVMFailed, err)
 	}

 	var scriptBuilder strings.Builder
@@ -239,8 +291,36 @@ func (vm *VM) shell(
 	_, err = stdinBuf.Write([]byte(scriptBuilder.String()))
 	if err != nil {
-		return fmt.Errorf("%w: failed to start script: %v", ErrFailed, err)
+		return fmt.Errorf("%w: failed to start script: %v", ErrVMFailed, err)
 	}
 	outputReaderWG.Wait()
 	return sess.Wait()
 }
+
+func (vm *VM) runScript(script *v1.VMScript, eventStreamer *client.EventStreamer) {
+	if eventStreamer != nil {
+		defer func() {
+			if err := eventStreamer.Close(); err != nil {
+				vm.logger.Errorf("errored during streaming events for startup script: %v", err)
+			}
+		}()
+	}
+
+	consumeLine := func(line string) {
+		if eventStreamer == nil {
+			return
+		}
+
+		eventStreamer.Stream(v1.Event{
+			Kind:      v1.EventKindLogLine,
+			Timestamp: time.Now().Unix(),
+			Payload:   line,
+		})
+	}
+
+	err := vm.shell(context.Background(), vm.Resource.Username, vm.Resource.Password,
+		script.ScriptContent, script.Env, consumeLine)
+	if err != nil {
+		vm.setErr(fmt.Errorf("%w: failed to run startup script: %v", ErrVMFailed, err))
+	}
+}
diff --git a/internal/worker/vmmanager/vmmanager.go b/internal/worker/vmmanager/vmmanager.go
index 739925d..9538d33 100644
--- a/internal/worker/vmmanager/vmmanager.go
+++ b/internal/worker/vmmanager/vmmanager.go
@@ -1,77 +1,41 @@
 package vmmanager

 import (
-	"context"
-	"errors"
-	"fmt"
-	v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
-	"go.uber.org/zap"
+	"github.com/cirruslabs/orchard/internal/worker/ondiskname"
 )

-var ErrFailed = errors.New("VM manager failed")
-
 type VMManager struct {
-	vms map[string]*VM
+	vms map[ondiskname.OnDiskName]*VM
 }

 func New() *VMManager {
 	return &VMManager{
-		vms: map[string]*VM{},
+		vms: map[ondiskname.OnDiskName]*VM{},
 	}
 }

-func (vmm *VMManager) Exists(vmResource v1.VM) bool {
-	_, ok := vmm.vms[vmResource.UID]
+func (vmm *VMManager) Exists(key ondiskname.OnDiskName) bool {
+	_, ok := vmm.vms[key]

 	return ok
 }

-func (vmm *VMManager) Get(vmResource v1.VM) (*VM, error) {
-	managedVM, ok := vmm.vms[vmResource.UID]
-	if !ok {
-		return nil, fmt.Errorf("%w: VM does not exist", ErrFailed)
-	}
+func (vmm *VMManager) Get(key ondiskname.OnDiskName) (*VM, bool) {
+	vm, ok := vmm.vms[key]

-	return managedVM, nil
+	return vm, ok
 }

-func (vmm *VMManager) Create(ctx context.Context, vmResource v1.VM, logger *zap.SugaredLogger) (*VM, error) {
-	if _, ok := vmm.vms[vmResource.UID]; ok {
-		return nil, fmt.Errorf("%w: VM already exists", ErrFailed)
-	}
-
-	managedVM, err := NewVM(ctx, vmResource, logger)
-	if err != nil {
-		return nil, err
-	}
-
-	vmm.vms[vmResource.UID] = managedVM
-
-	return managedVM, nil
+func (vmm *VMManager) Put(key ondiskname.OnDiskName, vm *VM) {
+	vmm.vms[key] = vm
 }

-func (vmm *VMManager) Stop(vmResource v1.VM) error {
-	managedVM, ok := vmm.vms[vmResource.UID]
-	if !ok {
-		return fmt.Errorf("%w: VM does not exist", ErrFailed)
-	}
-
-	return managedVM.Stop()
+func (vmm *VMManager) Delete(key ondiskname.OnDiskName) {
+	delete(vmm.vms, key)
 }

-func (vmm *VMManager) Delete(vmResource v1.VM) error {
-	managedVM, ok := vmm.vms[vmResource.UID]
-	if !ok {
-		return fmt.Errorf("%w: VM does not exist", ErrFailed)
-	}
-
-	if err := managedVM.Delete(); err != nil {
-		return err
-	}
-
-	delete(vmm.vms, vmResource.UID)
-
-	return nil
+func (vmm *VMManager) Len() int {
+	return len(vmm.vms)
 }

 func (vmm *VMManager) List() []*VM {
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index 5de29ef..830bc67 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -29,6 +29,7 @@ type Worker struct {
 	syncRequested chan bool
 	vmm           *vmmanager.VMManager
 	client        *client.Client
+	pollTicker    *time.Ticker
 	resources     v1.Resources
 	logger        *zap.SugaredLogger
 }
@@ -36,6 +37,7 @@ func New(client *client.Client, opts ...Option) (*Worker, error) {
 	worker := &Worker{
 		client:        client,
+		pollTicker:    time.NewTicker(pollInterval),
 		vmm:           vmmanager.New(),
 		syncRequested: make(chan bool, 1),
 	}
@@ -72,9 +74,33 @@ func (worker *Worker) Run(ctx context.Context) error {
 		if err := worker.runNewSession(ctx); err != nil {
 			return err
 		}
+
+		select {
+		case <-worker.pollTicker.C:
+			// continue
+		case <-ctx.Done():
+			return ctx.Err()
+		}
 	}
 }

+func (worker *Worker) Close() error {
+	var result error
+	for _, vm := range worker.vmm.List() {
+		err := vm.Stop()
+		if err != nil {
+			result = multierror.Append(result, err)
+		}
+	}
+	for _, vm := range worker.vmm.List() {
+		err := vm.Delete()
+		if err != nil {
+			result = multierror.Append(result, err)
+		}
+	}
+	return result
+}
+
 func (worker *Worker) runNewSession(ctx context.Context) error {
 	subCtx, cancel := context.WithCancel(ctx)
 	defer cancel()
@@ -113,7 +139,7 @@ func (worker *Worker) runNewSession(ctx context.Context) error {
 		select {
 		case <-worker.syncRequested:
-		case <-time.After(pollInterval):
+		case <-worker.pollTicker.C:
 			// continue
 		case <-subCtx.Done():
 			return subCtx.Err()
@@ -168,43 +194,47 @@ func (worker *Worker) syncVMs(ctx context.Context) error {
 	if err != nil {
 		return err
 	}
-
-	worker.logger.Infof("syncing %d VMs...", len(remoteVMs))
-
-	// Check if we need to stop any of the VMs
-	for _, vmResource := range remoteVMs {
-		if vmResource.Status == v1.VMStatusStopping && worker.vmm.Exists(vmResource) {
-			if err := worker.stopVM(vmResource); err != nil {
-				return err
-			}
-		}
+	remoteVMsIndex := map[ondiskname.OnDiskName]v1.VM{}
+	for _, remoteVM := range remoteVMs {
+		remoteVMsIndex[ondiskname.NewFromResource(remoteVM)] = remoteVM
 	}

-	// Handle pending VMs
-	for _, vmResource := range remoteVMs {
-		// handle pending VMs
-		if vmResource.Status == v1.VMStatusPending && !worker.vmm.Exists(vmResource) {
-			if err := worker.createVM(ctx, vmResource); err != nil {
-				return err
-			}
-		}
-	}
+	worker.logger.Infof("syncing %d remote VMs against %d local VMs...",
+		len(remoteVMsIndex), worker.vmm.Len())

-	// Sync in-memory VMs
 	for _, vm := range worker.vmm.List() {
-		remoteVM, ok := remoteVMs[vm.Resource.UID]
+		remoteVM, ok := remoteVMsIndex[vm.OnDiskName()]
 		if !ok {
-			if err := worker.deleteVM(vm.Resource); err != nil {
+			// Remote VM was deleted, delete local VM
+			//
+			// Note: this check needs to run for each VM
+			// before we attempt to create any VMs below.
+			if err := worker.deleteVM(vm); err != nil {
 				return err
 			}
-		} else if remoteVM.Status != v1.VMStatusFailed && vm.RunError != nil {
+		} else if remoteVM.Status != v1.VMStatusFailed && vm.Err() != nil {
+			// Local VM has failed, update remote VM
 			remoteVM.Status = v1.VMStatusFailed
-			remoteVM.StatusMessage = fmt.Sprintf("failed to run VM: %v", vm.RunError)
-			updatedVM, err := worker.client.VMs().Update(ctx, vm.Resource)
-			if err != nil {
+			remoteVM.StatusMessage = vm.Err().Error()
+			if _, err := worker.client.VMs().Update(ctx, remoteVM); err != nil {
+				return err
+			}
+		}
+	}
+
+	for _, vmResource := range remoteVMsIndex {
+		odn := ondiskname.NewFromResource(vmResource)
+
+		if vmResource.Status == v1.VMStatusPending && !worker.vmm.Exists(odn) {
+			// Remote VM was created, create local VM
+			if err := worker.createVM(ctx, odn, vmResource); err != nil {
+				return err
+			}
+
+			vmResource.Status = v1.VMStatusRunning
+			if _, err := worker.client.VMs().Update(ctx, vmResource); err != nil {
 				return err
 			}
-			vm.Resource = *updatedVM
 		}
 	}

@@ -216,6 +246,10 @@ func (worker *Worker) syncOnDiskVMs(ctx context.Context) error {
 	if err != nil {
 		return err
 	}
+	remoteVMsIndex := map[ondiskname.OnDiskName]v1.VM{}
+	for _, remoteVM := range remoteVMs {
+		remoteVMsIndex[ondiskname.NewFromResource(remoteVM)] = remoteVM
+	}

 	worker.logger.Infof("syncing on-disk VMs...")

@@ -238,14 +272,14 @@ func (worker *Worker) syncOnDiskVMs(ctx context.Context) error {
 			return err
 		}

-		remoteVM, ok := remoteVMs[onDiskName.UID]
+		remoteVM, ok := remoteVMsIndex[onDiskName]
 		if !ok {
 			// On-disk VM doesn't exist on the controller, delete it
 			_, _, err := tart.Tart(ctx, worker.logger, "delete", vmInfo.Name)
 			if err != nil {
 				return err
 			}
-		} else if remoteVM.Status == v1.VMStatusRunning && !worker.vmm.Exists(v1.VM{UID: onDiskName.UID}) {
+		} else if remoteVM.Status == v1.VMStatusRunning && !worker.vmm.Exists(onDiskName) {
 			// On-disk VM exist on the controller,
 			// but we don't know about it, so
 			// mark it as failed
@@ -261,148 +295,41 @@ func (worker *Worker) syncOnDiskVMs(ctx context.Context) error {
 	return nil
 }

-func (worker *Worker) deleteVM(vmResource v1.VM) error {
-	worker.logger.Debugf("deleting VM %s (%s)", vmResource.Name, vmResource.UID)
-
-	if !vmResource.TerminalState() {
-		if err := worker.stopVM(vmResource); err != nil {
-			return err
-		}
-	}
-
-	// Delete VM locally, report to the controller
-	if worker.vmm.Exists(vmResource) {
-		if err := worker.vmm.Delete(vmResource); err != nil {
-			return err
-		}
-	}
-
-	worker.logger.Infof("deleted VM %s (%s)", vmResource.Name, vmResource.UID)
-
-	return nil
-}
-
-func (worker *Worker) createVM(ctx context.Context, vmResource v1.VM) error {
-	worker.logger.Debugf("creating VM %s (%s)", vmResource.Name, vmResource.UID)
-
-	// Create or update VM locally
-	vm, err := worker.vmm.Create(ctx, vmResource, worker.logger)
-	if err != nil {
-		vmResource.Status = v1.VMStatusFailed
-		vmResource.StatusMessage = fmt.Sprintf("VM creation failed: %v", err)
-		_, updateErr := worker.client.VMs().Update(context.Background(), vmResource)
-		if updateErr != nil {
-			worker.logger.Errorf("failed to update VM %s (%s) remotely: %s", vmResource.Name, vmResource.UID, updateErr.Error())
-		}
+func (worker *Worker) deleteVM(vm *vmmanager.VM) error {
+	if err := vm.Stop(); err != nil {
 		return err
 	}

-	worker.logger.Infof("spawned VM %s (%s)", vmResource.Name, vmResource.UID)
-
-	vmResource.Status = v1.VMStatusRunning
-	_, updateErr := worker.client.VMs().Update(context.Background(), vmResource)
-	if updateErr != nil {
worker.logger.Errorf("failed to update VM %s (%s) remotely: %s", vmResource.Name, vmResource.UID, updateErr.Error()) + if err := vm.Delete(); err != nil { + return err } - go func() { - err := worker.execScript(vmResource, vm.Resource.StartupScript) - if err != nil { - vmResource.Status = v1.VMStatusFailed - vmResource.StatusMessage = fmt.Sprintf("failed to run script: %v", err) - _, updateErr := worker.client.VMs().Update(context.Background(), vmResource) - if updateErr != nil { - worker.logger.Errorf("failed to update VM %s (%s) remotely: %s", vmResource.Name, vmResource.UID, updateErr.Error()) - } - } - }() + worker.vmm.Delete(vm.OnDiskName()) return nil } -func (worker *Worker) execScript(vmResource v1.VM, script *v1.VMScript) error { - if script == nil { - return nil - } - vm, err := worker.vmm.Get(vmResource) +func (worker *Worker) createVM(ctx context.Context, odn ondiskname.OnDiskName, vmResource v1.VM) error { + eventStreamer := worker.client.VMs().StreamEvents(vmResource.Name) + + vm, err := vmmanager.NewVM(ctx, vmResource, eventStreamer, worker.logger) if err != nil { - return nil + return err } - eventsStreamer := worker.client.VMs().StreamEvents(vmResource.Name) - defer func() { - err := eventsStreamer.Close() - if err != nil { - worker.logger.Errorf("errored during streaming events for %s (%s): %w", vmResource.Name, vmResource.UID, err) - } - }() - err = vm.Shell(context.Background(), vmResource.Username, vmResource.Password, - script.ScriptContent, script.Env, - func(line string) { - eventsStreamer.Stream(v1.Event{ - Kind: v1.EventKindLogLine, - Timestamp: time.Now().Unix(), - Payload: line, - }) - }) - if err != nil { - worker.logger.Errorf("failed to run script for VM %s (%s): %s", vmResource.Name, vmResource.UID, err.Error()) - } - return err + worker.vmm.Put(odn, vm) + + return nil } -func (worker *Worker) stopVM(vmResource v1.VM) error { - worker.logger.Debugf("stopping VM %s (%s)", vmResource.Name, vmResource.UID) - - // Create or update VM locally - if !worker.vmm.Exists(vmResource) { - return nil - } - - shutdownScriptErr := worker.execScript(vmResource, vmResource.ShutdownScript) - stopErr := worker.vmm.Stop(vmResource) - vmResource.Status = v1.VMStatusStopped - if stopErr != nil { - vmResource.Status = v1.VMStatusFailed - vmResource.StatusMessage = fmt.Sprintf("failed to stop vm: %v", stopErr) - } - if shutdownScriptErr != nil { - vmResource.Status = v1.VMStatusFailed - vmResource.StatusMessage = fmt.Sprintf("failed to run shutdown script: %v", shutdownScriptErr) - } - - _, err := worker.client.VMs().Update(context.Background(), vmResource) - if err != nil { - worker.logger.Errorf("failed to update VM %s (%s) remotely: %s", vmResource.Name, vmResource.UID, err.Error()) - } - return stopErr -} - -func (worker *Worker) DeleteAllVMs() error { - var result error - for _, vm := range worker.vmm.List() { - err := vm.Stop() - if err != nil { - result = multierror.Append(result, err) - } - } - for _, vm := range worker.vmm.List() { - err := vm.Delete() - if err != nil { - result = multierror.Append(result, err) - } - } - return result -} - -func (worker *Worker) GPRCMetadata() metadata.MD { +func (worker *Worker) grpcMetadata() metadata.MD { return metadata.Join( worker.client.GPRCMetadata(), metadata.Pairs(rpc.MetadataWorkerNameKey, worker.name), ) } -func (worker *Worker) RequestVMSyncing() { +func (worker *Worker) requestVMSyncing() { select { case worker.syncRequested <- true: worker.logger.Debugf("Successfully requested syncing") diff --git a/pkg/client/client.go 
diff --git a/pkg/client/client.go b/pkg/client/client.go
index c4d8e8e..2e0726f 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -173,8 +173,9 @@ func (client *Client) request(
 	}()

 	if response.StatusCode != http.StatusOK {
-		return fmt.Errorf("%w to make a request: %d %s",
-			ErrFailed, response.StatusCode, http.StatusText(response.StatusCode))
+		return fmt.Errorf("%w to make a request: %d %s%s",
+			ErrFailed, response.StatusCode, http.StatusText(response.StatusCode),
+			detailsFromErrorResponseBody(response.Body))
 	}

 	if out != nil {
@@ -191,6 +192,27 @@ func (client *Client) request(
 	return nil
 }

+func detailsFromErrorResponseBody(body io.Reader) string {
+	bodyBytes, err := io.ReadAll(body)
+	if err != nil {
+		return ""
+	}
+
+	var errorResponse struct {
+		Message string `json:"message"`
+	}
+
+	if err := json.Unmarshal(bodyBytes, &errorResponse); err != nil {
+		return ""
+	}
+
+	if errorResponse.Message != "" {
+		return fmt.Sprintf(" (%s)", errorResponse.Message)
+	}
+
+	return ""
+}
+
 func (client *Client) wsRequest(
 	_ context.Context,
 	path string,
diff --git a/pkg/client/vms.go b/pkg/client/vms.go
index f1985e3..9fefc67 100644
--- a/pkg/client/vms.go
+++ b/pkg/client/vms.go
@@ -23,23 +23,23 @@ func (service *VMsService) Create(ctx context.Context, vm *v1.VM) error {
 	return nil
 }

-func (service *VMsService) FindForWorker(ctx context.Context, worker string) (map[string]v1.VM, error) {
+func (service *VMsService) FindForWorker(ctx context.Context, worker string) ([]v1.VM, error) {
 	allVms, err := service.List(ctx)
 	if err != nil {
 		return nil, err
 	}

-	var filteredVms = make(map[string]v1.VM)
+	var result []v1.VM

 	for _, vmResource := range allVms {
 		if vmResource.Worker != worker {
 			continue
 		}

-		filteredVms[vmResource.UID] = vmResource
+		result = append(result, vmResource)
 	}

-	return filteredVms, nil
+	return result, nil
 }

 func (service *VMsService) List(ctx context.Context) ([]v1.VM, error) {
@@ -66,20 +66,6 @@ func (service *VMsService) Get(ctx context.Context, name string) (*v1.VM, error)
 	return &vm, nil
 }

-func (service *VMsService) Stop(ctx context.Context, name string) (*v1.VM, error) {
-	var vm v1.VM
-
-	err := service.client.request(ctx, http.MethodGet, fmt.Sprintf("vms/%s", name),
-		nil, &vm, nil)
-	if err != nil {
-		return nil, err
-	}
-
-	vm.Status = v1.VMStatusStopping
-
-	return service.Update(ctx, vm)
-}
-
 func (service *VMsService) Update(ctx context.Context, vm v1.VM) (*v1.VM, error) {
 	var updatedVM v1.VM
 	err := service.client.request(ctx, http.MethodPut, fmt.Sprintf("vms/%s", vm.Name),
diff --git a/pkg/resource/v1/restart_policy.go b/pkg/resource/v1/restart_policy.go
new file mode 100644
index 0000000..0c2f6da
--- /dev/null
+++ b/pkg/resource/v1/restart_policy.go
@@ -0,0 +1,26 @@
+package v1
+
+import (
+	"errors"
+	"fmt"
+)
+
+var ErrInvalidRestartPolicy = errors.New("invalid restart policy")
+
+type RestartPolicy string
+
+const (
+	RestartPolicyNever     RestartPolicy = "Never"
+	RestartPolicyOnFailure RestartPolicy = "OnFailure"
+)
+
+func NewRestartPolicyFromString(s string) (RestartPolicy, error) {
+	switch s {
+	case string(RestartPolicyNever):
+		return RestartPolicyNever, nil
+	case string(RestartPolicyOnFailure):
+		return RestartPolicyOnFailure, nil
+	default:
+		return "", fmt.Errorf("%w %q", ErrInvalidRestartPolicy, s)
+	}
+}
"github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewRestartPolicyFromString(t *testing.T) { + _, err := v1.NewRestartPolicyFromString("") + assert.Error(t, err, "empty restart policy should yield an error") + + _, err = v1.NewRestartPolicyFromString("non-existent") + assert.Error(t, err, "non-existent restart policy should yield an error") + + _, err = v1.NewRestartPolicyFromString("never") + assert.Error(t, err, "improperly capitalized but existent policy should yield an error") + + restartPolicy, err := v1.NewRestartPolicyFromString("Never") + assert.NoError(t, err, "Never policy should be parsed correctly") + assert.Equal(t, v1.RestartPolicyNever, restartPolicy) + + restartPolicy, err = v1.NewRestartPolicyFromString("OnFailure") + assert.NoError(t, err, "OnFailure policy should be parsed correctly") + assert.Equal(t, v1.RestartPolicyOnFailure, restartPolicy) +} diff --git a/pkg/resource/v1/v1.go b/pkg/resource/v1/v1.go index 1a8cc8a..f9ac759 100644 --- a/pkg/resource/v1/v1.go +++ b/pkg/resource/v1/v1.go @@ -33,10 +33,13 @@ type VM struct { // Worker field is set by the Controller to assign this VM to a specific Worker. Worker string `json:"worker"` - Username string `json:"username"` - Password string `json:"password"` - StartupScript *VMScript `json:"startup_script"` - ShutdownScript *VMScript `json:"shutdown_script"` + Username string `json:"username"` + Password string `json:"password"` + StartupScript *VMScript `json:"startup_script"` + + RestartPolicy RestartPolicy `json:"restart_policy"` + RestartedAt time.Time `json:"restarted_at"` + RestartCount uint64 `json:"restart_count"` // UID is a useful field for avoiding data races within a single Name. // @@ -67,7 +70,7 @@ type VMScript struct { } func (vm VM) TerminalState() bool { - return vm.Status == VMStatusStopped || vm.Status == VMStatusFailed + return vm.Status == VMStatusFailed } type VMStatus string @@ -83,13 +86,6 @@ const ( // VMStatusFailed is set by both the Controller and the Worker to indicate a failure // that prevented the VM resource from reaching the VMStatusRunning state. VMStatusFailed VMStatus = "failed" - - // VMStatusStopping is set by the Controller to indicate that a VM resource needs to be stopped but not deleted. - VMStatusStopping VMStatus = "stopping" - - // VMStatusStopped is set by both the Worker to indicate that a particular VM resource has been stopped successfully - // (either via API or from within a VM via `sudo shutdown -now`). - VMStatusStopped VMStatus = "stopped" ) type ControllerInfo struct {