From b4ee8880775e6d99bc9353e19a0668ee3d4ef951 Mon Sep 17 00:00:00 2001 From: Nikolay Edigaryev Date: Fri, 6 Mar 2026 17:16:26 +0100 Subject: [PATCH] Support Vetu virtualization on Linux in addition to Tart on macOS --- .cirrus.yml | 4 +- .github/workflows/main.yml | 24 ++ api/openapi.yaml | 42 ++- go.mod | 4 +- go.sum | 4 +- internal/controller/api_vms.go | 39 ++- internal/controller/api_vms_exec.go | 15 +- internal/controller/api_vms_ip.go | 66 +++- internal/controller/api_vms_portforward.go | 161 ++++----- internal/controller/api_workers.go | 10 + internal/controller/controller.go | 78 +++++ internal/controller/scheduler/scheduler.go | 13 +- internal/imageconstant/imageconstant.go | 5 +- internal/tests/events_pagination_test.go | 12 +- internal/tests/exec_test.go | 10 +- internal/tests/integration_test.go | 131 ++++--- internal/tests/ip_endpoint_test.go | 18 +- .../platformdependent/platformdependent.go | 59 ++++ internal/tests/spec_update_test.go | 26 +- internal/tests/sshserver_test.go | 18 +- internal/worker/option.go | 3 +- .../iokitregistry/iokitregistry_test.go | 2 + internal/worker/runtime/runtime.go | 26 ++ internal/worker/runtime/synthetic.go | 51 +++ internal/worker/runtime/tart.go | 45 +++ internal/worker/runtime/vetu.go | 45 +++ internal/worker/vmmanager/base/base.go | 173 ++++++++++ internal/worker/vmmanager/base/cmd.go | 83 +++++ internal/worker/vmmanager/tart/cmd.go | 87 +---- internal/worker/vmmanager/tart/tart.go | 166 +-------- internal/worker/vmmanager/vetu/cmd.go | 19 ++ internal/worker/vmmanager/vetu/vetu.go | 322 ++++++++++++++++++ internal/worker/vmmanager/vmmanager.go | 7 + internal/worker/worker.go | 44 +-- pkg/resource/v1/v1.go | 128 +++++++ pkg/resource/v1/worker.go | 6 + 36 files changed, 1438 insertions(+), 508 deletions(-) create mode 100644 .github/workflows/main.yml create mode 100644 internal/tests/platformdependent/platformdependent.go create mode 100644 internal/worker/runtime/runtime.go create mode 100644 internal/worker/runtime/synthetic.go create mode 100644 internal/worker/runtime/tart.go create mode 100644 internal/worker/runtime/vetu.go create mode 100644 internal/worker/vmmanager/base/cmd.go create mode 100644 internal/worker/vmmanager/vetu/cmd.go create mode 100644 internal/worker/vmmanager/vetu/vetu.go diff --git a/.cirrus.yml b/.cirrus.yml index 665284f..4a9f1d3 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -5,8 +5,8 @@ task: name: dev-mini resources: tart-vms: 2 - pull_script: - - tart pull ghcr.io/cirruslabs/macos-sonoma-base:latest + pre_pull_script: + - tart pull ghcr.io/cirruslabs/macos-tahoe-base:latest test_script: - go test -ldflags="-B gobuildid" -v -count=1 ./... always: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..5c878ac --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,24 @@ +name: Main + +on: + push: + +jobs: + test: + name: Test (Linux) + runs-on: ghcr.io/cirruslabs/ubuntu-runner-amd64:24.04-md + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: stable + - name: Install Vetu + run: | + sudo apt-get update && sudo apt-get -y install apt-transport-https ca-certificates + echo "deb [trusted=yes] https://apt.fury.io/cirruslabs/ /" | sudo tee /etc/apt/sources.list.d/cirruslabs.list + sudo apt-get update && sudo apt-get -y install vetu + - name: Pre-pull default Vetu image for use in tests + run: | + vetu pull ghcr.io/cirruslabs/ubuntu-runner-amd64:latest + - name: Run tests + run: go test -v -count=1 ./... diff --git a/api/openapi.yaml b/api/openapi.yaml index 9622964..29823a2 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -558,6 +558,32 @@ components: title: Virtual Machine Specification type: object properties: + os: + type: string + description: | + Operating system used by a VM. + + Set to `linux` to work around the Apple's limitation of 2 macOS VMs per host. + + This field cannot be changed after the VM is created. + default: darwin + enum: [ darwin, linux ] + arch: + type: string + description: | + Hardware architecture to use for a VM. + + This field cannot be changed after the VM is created. + default: arm64 + enum: [ arm64, amd64 ] + runtime: + type: string + description: | + Runtime to use for a VM. + + This field cannot be changed after the VM is created. + default: tart + enum: [ tart, vetu ] image: type: string description: VM image for this VM @@ -709,20 +735,20 @@ components: When set to `stopped` or `suspended`, the VM does not consume any `resources` and can serve as a source for creating new Orchard VMs on the same worker. See - `tartName` for more details. + `localName` for more details. Note that you can only transition into `stopped` or `suspended` only once at the moment. default: running enum: [ running, stopped, suspended ] - tartName: + localName: type: string description: | - Name of the Tart VM backing this VM resource. + Name of the local VM backing this VM resource. - `tartName` is specific to a worker, whereas `name` is cluster-wide. + `localName` is specific to a worker, whereas `name` is cluster-wide. - `tartName` is useful in combination with `powerState` for creating stopped or suspended VMs + `localName` is useful in combination with `powerState` for creating stopped or suspended VMs that can be used to start or resume new VMs on the same worker. However, with great power comes great responsibility. You need to make sure: @@ -735,6 +761,12 @@ components: new MAC addresses automatically, which will stop them from booting, since the suspend‑resume machinery expects the same MAC address readOnly: true + tartName: + type: string + description: | + Deprecated alias for `localName`. + readOnly: true + deprecated: true VMState: title: Virtual Machine State type: object diff --git a/go.mod b/go.mod index 2670c52..901ce15 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ go 1.25.1 replace github.com/gin-gonic/gin v1.11.0 => github.com/gin-gonic/gin v1.10.0 require ( - github.com/avast/retry-go v3.0.0+incompatible github.com/avast/retry-go/v4 v4.7.0 + github.com/avast/retry-go/v5 v5.0.0 github.com/cirruslabs/chacha v0.16.3 github.com/coder/websocket v1.8.14 github.com/deckarep/golang-set/v2 v2.8.0 @@ -47,6 +47,7 @@ require ( golang.org/x/net v0.51.0 golang.org/x/sync v0.19.0 golang.org/x/term v0.40.0 + golang.org/x/text v0.34.0 google.golang.org/grpc v1.79.1 google.golang.org/protobuf v1.36.11 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -140,6 +141,5 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/arch v0.23.0 // indirect golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect ) diff --git a/go.sum b/go.sum index af1feb7..614df74 100644 --- a/go.sum +++ b/go.sum @@ -21,10 +21,10 @@ github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk= -github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0= -github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= github.com/avast/retry-go/v4 v4.7.0 h1:yjDs35SlGvKwRNSykujfjdMxMhMQQM0TnIjJaHB+Zio= github.com/avast/retry-go/v4 v4.7.0/go.mod h1:ZMPDa3sY2bKgpLtap9JRUgk2yTAba7cgiFhqxY2Sg6Q= +github.com/avast/retry-go/v5 v5.0.0 h1:kf1Qc2UsTZ4qq8elDymqfbISvkyMuhgRxuJqX2NHP7k= +github.com/avast/retry-go/v5 v5.0.0/go.mod h1://d+usmKWio1agtZfS1H/ltTqwtIfBnRq9zEwjc3eH8= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= diff --git a/internal/controller/api_vms.go b/internal/controller/api_vms.go index 58c07ab..3bdf281 100644 --- a/internal/controller/api_vms.go +++ b/internal/controller/api_vms.go @@ -43,13 +43,16 @@ func (controller *Controller) createVM(ctx *gin.Context) responder.Responder { return responder.JSON(http.StatusPreconditionFailed, NewErrorResponse("VM image is empty")) } + // Provide defaults vm.Status = v1.VMStatusPending vm.CreatedAt = time.Now() vm.RestartedAt = time.Time{} vm.RestartCount = 0 vm.UID = uuid.New().String() vm.PowerState = v1.PowerStateRunning - vm.TartName = ondiskname.New(vm.Name, vm.UID, vm.RestartCount).String() + vm.LocalName = ondiskname.New(vm.Name, vm.UID, vm.RestartCount).String() + //nolint:staticcheck // yes, this is deprecated, but we still maintain it for backward compatibility + vm.TartName = vm.LocalName vm.Generation = 0 vm.ObservedGeneration = 0 vm.Conditions = []v1.Condition{ @@ -59,6 +62,17 @@ func (controller *Controller) createVM(ctx *gin.Context) responder.Responder { }, } + // Provide platform defaults + if vm.OS == "" { + vm.OS = v1.OSDarwin + } + if vm.Arch == "" { + vm.Arch = v1.ArchitectureARM64 + } + if vm.Runtime == "" { + vm.Runtime = v1.RuntimeTart + } + // Softnet-specific logic: automatically enable Softnet when NetSoftnetAllow or NetSoftnetBlock are set // and propagate deprecated and non-deprecated boolean fields into each other if vm.NetSoftnetDeprecated || vm.NetSoftnet || len(vm.NetSoftnetAllow) != 0 || len(vm.NetSoftnetBlock) != 0 { @@ -66,12 +80,17 @@ func (controller *Controller) createVM(ctx *gin.Context) responder.Responder { vm.NetSoftnet = true } - // Provide resource defaults - if vm.Resources == nil { - vm.Resources = make(v1.Resources) - } - if _, ok := vm.Resources[v1.ResourceTartVMs]; !ok { - vm.Resources[v1.ResourceTartVMs] = 1 + // Apple limits the number of macOS VMs to 2, + // so we need to provide a resource default + // (if not otherwise overridden by the user) + // to avoid a case when more than 2 VMs run + if vm.OS == v1.OSDarwin && vm.Runtime == v1.RuntimeTart { + if vm.Resources == nil { + vm.Resources = make(v1.Resources) + } + if _, ok := vm.Resources[v1.ResourceTartVMs]; !ok { + vm.Resources[v1.ResourceTartVMs] = 1 + } } // Validate image pull policy and provide a default value if it's missing @@ -148,6 +167,12 @@ func (controller *Controller) updateVMSpec(ctx *gin.Context) responder.Responder NewErrorResponse("cannot update VM in a terminal state")) } + // Platform sanity checks + if dbVM.OS != userVM.OS || dbVM.Arch != userVM.Arch || dbVM.Runtime != userVM.Runtime { + return responder.JSON(http.StatusPreconditionFailed, NewErrorResponse("\"os\", \"arch\" "+ + "and \"runtime\" fields cannot be modified")) + } + // Softnet-specific logic: automatically enable Softnet when NetSoftnetAllow or NetSoftnetBlock are set // and propagate deprecated and non-deprecated boolean fields into each other if userVM.NetSoftnetDeprecated || userVM.NetSoftnet || len(userVM.NetSoftnetAllow) != 0 || len(userVM.NetSoftnetBlock) != 0 { diff --git a/internal/controller/api_vms_exec.go b/internal/controller/api_vms_exec.go index 9dd107b..de384ca 100644 --- a/internal/controller/api_vms_exec.go +++ b/internal/controller/api_vms_exec.go @@ -6,10 +6,12 @@ import ( "errors" "fmt" "io" + "net" "net/http" "strconv" "time" + "github.com/avast/retry-go/v5" "github.com/cirruslabs/orchard/internal/controller/sshexec" "github.com/cirruslabs/orchard/internal/execstream" "github.com/cirruslabs/orchard/internal/responder" @@ -51,12 +53,19 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { } // Establish a port-forwarding connection to a VM's SSH port - portForwardConn, portForwardCancel, err := controller.portForwardConnection(ctx, waitContext, - vm.Worker, vm.UID, 22) + portForwardConn, err := retry.NewWithData[net.Conn]( + retry.Context(waitContext), + retry.DelayType(retry.FixedDelay), + retry.Delay(time.Second), + retry.Attempts(0), + retry.LastErrorOnly(true), + ).Do(func() (net.Conn, error) { + return controller.portForwardConnection(ctx, waitContext, vm.Worker, vm.UID, 22) + }) if err != nil { return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("%v", err)) } - defer portForwardCancel() + defer portForwardConn.Close() // Establish an SSH connection to a VM exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin) diff --git a/internal/controller/api_vms_ip.go b/internal/controller/api_vms_ip.go index 8e1702d..88c889f 100644 --- a/internal/controller/api_vms_ip.go +++ b/internal/controller/api_vms_ip.go @@ -2,11 +2,13 @@ package controller import ( "context" + "errors" "fmt" "net/http" "strconv" "time" + "github.com/avast/retry-go/v5" "github.com/cirruslabs/orchard/internal/responder" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/cirruslabs/orchard/rpc" @@ -14,6 +16,8 @@ import ( "github.com/google/uuid" ) +var errIPRequest = errors.New("failed to request VM's IP") + func (controller *Controller) ip(ctx *gin.Context) responder.Responder { if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite, v1.ServiceAccountRoleComputeConnect); responder != nil { @@ -38,40 +42,68 @@ func (controller *Controller) ip(ctx *gin.Context) responder.Responder { } // Send an IP resolution request and wait for the result + ip, err := retry.NewWithData[string]( + retry.Context(waitContext), + retry.DelayType(retry.FixedDelay), + retry.Delay(time.Second), + retry.Attempts(0), + retry.LastErrorOnly(true), + ).Do(func() (string, error) { + return controller.vmIP(ctx, waitContext, vm.Worker, vm.UID) + }) + if err != nil { + if errors.Is(err, errIPRequest) { + controller.logger.Warnf("failed to request VM's IP from the worker %s: %v", + vm.Worker, err) + + return responder.Code(http.StatusServiceUnavailable) + } + + return responder.Error(err) + } + + result := struct { + IP string `json:"ip"` + }{ + IP: ip, + } + + return responder.JSON(http.StatusOK, &result) +} + +func (controller *Controller) vmIP( + ctx context.Context, + waitContext context.Context, + workerName string, + vmUID string, +) (string, error) { + // Send an IP resolution request and wait for the result. session := uuid.New().String() boomerangConnCh, cancel := controller.ipRendezvous.Request(ctx, session) defer cancel() - err = controller.workerNotifier.Notify(waitContext, vm.Worker, &rpc.WatchInstruction{ + err := controller.workerNotifier.Notify(waitContext, workerName, &rpc.WatchInstruction{ Action: &rpc.WatchInstruction_ResolveIpAction{ ResolveIpAction: &rpc.WatchInstruction_ResolveIP{ Session: session, - VmUid: vm.UID, + VmUid: vmUID, }, }, }) if err != nil { - controller.logger.Warnf("failed to request VM's IP from the worker %s: %v", - vm.Worker, err) - - return responder.Code(http.StatusServiceUnavailable) + return "", fmt.Errorf("%w: failed to request VM's IP from the worker %s: %v", + errIPRequest, workerName, err) } select { case rendezvousResponse := <-boomerangConnCh: if rendezvousResponse.ErrorMessage != "" { - return responder.Error(fmt.Errorf("VM's IP resolution on the worker %s failed: %s", - vm.Worker, rendezvousResponse.ErrorMessage)) + return "", fmt.Errorf("VM's IP resolution on the worker %s failed: %s", + workerName, rendezvousResponse.ErrorMessage) } - result := struct { - IP string `json:"ip"` - }{ - IP: rendezvousResponse.Result, - } - - return responder.JSON(http.StatusOK, &result) - case <-ctx.Done(): - return responder.Error(ctx.Err()) + return rendezvousResponse.Result, nil + case <-waitContext.Done(): + return "", waitContext.Err() } } diff --git a/internal/controller/api_vms_portforward.go b/internal/controller/api_vms_portforward.go index 20d085a..fa18e88 100644 --- a/internal/controller/api_vms_portforward.go +++ b/internal/controller/api_vms_portforward.go @@ -8,6 +8,7 @@ import ( "strconv" "time" + "github.com/avast/retry-go/v5" storepkg "github.com/cirruslabs/orchard/internal/controller/store" "github.com/cirruslabs/orchard/internal/netconncancel" "github.com/cirruslabs/orchard/internal/proxy" @@ -22,6 +23,8 @@ import ( "google.golang.org/grpc/status" ) +var errPortForwardRequest = errors.New("failed to request port forwarding") + func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responder { if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite, v1.ServiceAccountRoleComputeConnect); responder != nil { @@ -67,105 +70,89 @@ func (controller *Controller) portForward( port uint32, ) responder.Responder { // Request and wait for a connection with a worker - rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx) - defer rendezvousCtxCancel() - - session := uuid.New().String() - - boomerangConnCh, cancel := controller.connRendezvous.Request(rendezvousCtx, session) - defer cancel() - - // Send request to worker to initiate port forwarding connection back to us - err := controller.workerNotifier.Notify(notifyContext, workerName, &rpc.WatchInstruction{ - Action: &rpc.WatchInstruction_PortForwardAction{ - PortForwardAction: &rpc.WatchInstruction_PortForward{ - Session: session, - VmUid: vmUID, - Port: port, - }, - }, + rendezvousConn, err := retry.NewWithData[net.Conn]( + retry.Context(notifyContext), + retry.DelayType(retry.FixedDelay), + retry.Delay(time.Second), + retry.Attempts(0), + retry.LastErrorOnly(true), + ).Do(func() (net.Conn, error) { + return controller.portForwardConnection(ctx, notifyContext, workerName, vmUID, port) }) if err != nil { - controller.logger.Warnf("failed to request port forwarding from the worker %s: %v", - workerName, err) + if errors.Is(err, errPortForwardRequest) { + controller.logger.Warnf("failed to request port forwarding from the worker %s: %v", + workerName, err) - return responder.Code(http.StatusServiceUnavailable) + return responder.Code(http.StatusServiceUnavailable) + } + + return responder.Error(err) } // Worker will asynchronously start port forwarding, so we wait - select { - case rendezvousResponse := <-boomerangConnCh: - if rendezvousResponse.ErrorMessage != "" { - return responder.Error(fmt.Errorf("failed to establish port forwarding session on the worker: %s", - rendezvousResponse.ErrorMessage)) - } + wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + return responder.Error(err) + } + defer func() { + // Ensure that we always close the accepted WebSocket connection, + // otherwise resource leak is possible[1] + // + // [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044 + _ = wsConn.CloseNow() + }() - wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ - OriginPatterns: []string{"*"}, - }) - if err != nil { - return responder.Error(err) - } - defer func() { - // Ensure that we always close the accepted WebSocket connection, - // otherwise resource leak is possible[1] - // - // [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044 - _ = wsConn.CloseNow() - }() + expectedMsgType := websocket.MessageBinary - expectedMsgType := websocket.MessageBinary + // Backwards compatibility with older Orchard clients + // using "golang.org/x/net/websocket" package + if ctx.Request.Header.Get("User-Agent") == "" { + expectedMsgType = websocket.MessageText + } - // Backwards compatibility with older Orchard clients - // using "golang.org/x/net/websocket" package - if ctx.Request.Header.Get("User-Agent") == "" { - expectedMsgType = websocket.MessageText - } + proxyConnectionsErrCh := make(chan error, 1) + wsConnAsNetConn := websocket.NetConn(ctx, wsConn, expectedMsgType) - proxyConnectionsErrCh := make(chan error, 1) - wsConnAsNetConn := websocket.NetConn(ctx, wsConn, expectedMsgType) - fromWorkerConnectionWithCancel := netconncancel.New(rendezvousResponse.Result, rendezvousCtxCancel) + go func() { + proxyConnectionsErrCh <- proxy.Connections(wsConnAsNetConn, rendezvousConn) + }() - go func() { - proxyConnectionsErrCh <- proxy.Connections(wsConnAsNetConn, fromWorkerConnectionWithCancel) - }() + for { + select { + case err := <-proxyConnectionsErrCh: + if err != nil { + var websocketCloseError websocket.CloseError - for { - select { - case err := <-proxyConnectionsErrCh: - if err != nil { - var websocketCloseError websocket.CloseError - - // Normal closure from the user - if errors.As(err, &websocketCloseError) && websocketCloseError.Code == websocket.StatusNormalClosure { - return responder.Empty() - } - - if errors.Is(err, context.Canceled) { - return responder.Empty() - } - - if status, ok := status.FromError(err); ok && status.Code() == codes.Canceled { - return responder.Empty() - } - - controller.logger.Warnf("port forwarding: failed to proxy connections: %v", err) + // Normal closure from the user + if errors.As(err, &websocketCloseError) && websocketCloseError.Code == websocket.StatusNormalClosure { + return responder.Empty() } - return responder.Empty() - case <-time.After(controller.pingInterval): - pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second) - - if err := wsConn.Ping(pingCtx); err != nil { - controller.logger.Warnf("port forwarding: failed to ping the client, "+ - "connection might time out: %v", err) + if errors.Is(err, context.Canceled) { + return responder.Empty() } - pingCtxCancel() + if status, ok := status.FromError(err); ok && status.Code() == codes.Canceled { + return responder.Empty() + } + + controller.logger.Warnf("port forwarding: failed to proxy connections: %v", err) } + + return responder.Empty() + case <-time.After(controller.pingInterval): + pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second) + + if err := wsConn.Ping(pingCtx); err != nil { + controller.logger.Warnf("port forwarding: failed to ping the client, "+ + "connection might time out: %v", err) + } + + pingCtxCancel() } - case <-ctx.Done(): - return responder.Error(ctx.Err()) } } @@ -210,7 +197,7 @@ func (controller *Controller) portForwardConnection( workerName string, vmUID string, port uint32, -) (net.Conn, context.CancelFunc, error) { +) (net.Conn, error) { // Create a rendezvous connection point rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx) @@ -235,8 +222,8 @@ func (controller *Controller) portForwardConnection( if err != nil { cancel() - return nil, nil, fmt.Errorf("failed to request port forwarding from the worker %s: %v", - workerName, err) + return nil, fmt.Errorf("%w: failed to request port forwarding from the worker %s: %v", + errPortForwardRequest, workerName, err) } // Wait for the worker to respond @@ -245,14 +232,14 @@ func (controller *Controller) portForwardConnection( if rendezvousResponse.ErrorMessage != "" { cancel() - return nil, nil, fmt.Errorf("failed to establish port forwarding session on the worker: %s", + return nil, fmt.Errorf("failed to establish port forwarding session on the worker: %s", rendezvousResponse.ErrorMessage) } - return rendezvousResponse.Result, cancel, nil + return netconncancel.New(rendezvousResponse.Result, cancel), nil case <-waitContext.Done(): cancel() - return nil, nil, waitContext.Err() + return nil, waitContext.Err() } } diff --git a/internal/controller/api_workers.go b/internal/controller/api_workers.go index 009fbde..319020a 100644 --- a/internal/controller/api_workers.go +++ b/internal/controller/api_workers.go @@ -30,6 +30,14 @@ func (controller *Controller) createWorker(ctx *gin.Context) responder.Responder NewErrorResponse("worker name %v", err)) } + // Provide platform defaults + if worker.Arch == "" { + worker.Arch = v1.ArchitectureARM64 + } + if worker.Runtime == "" { + worker.Runtime = v1.RuntimeTart + } + currentTime := time.Now() if worker.LastSeen.IsZero() { worker.LastSeen = currentTime @@ -97,6 +105,8 @@ func (controller *Controller) createWorker(ctx *gin.Context) responder.Responder "from a different machine ID, delete this worker first to be able to re-create it")) } + dbWorker.Arch = worker.Arch + dbWorker.Runtime = worker.Runtime dbWorker.LastSeen = worker.LastSeen dbWorker.Resources = worker.Resources dbWorker.Labels = worker.Labels diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 3d1fec6..3983082 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -192,6 +192,16 @@ func New(opts ...Option) (*Controller, error) { ErrInitFailed, err) } + // Migrate VMs that were created before platform fields were introduced + if err := controller.vmsEnsurePlatformDefaults(); err != nil { + return nil, fmt.Errorf("%w: failed to migrate VM platform defaults: %v", ErrInitFailed, err) + } + + // Migrate workers that were created before platform fields were introduced + if err := controller.workersEnsurePlatformDefaults(); err != nil { + return nil, fmt.Errorf("%w: failed to migrate VM platform defaults: %v", ErrInitFailed, err) + } + // Metrics if err := controller.initializeMetrics(); err != nil { return nil, err @@ -200,6 +210,74 @@ func New(opts ...Option) (*Controller, error) { return controller, nil } +func (controller *Controller) vmsEnsurePlatformDefaults() error { + return controller.store.Update(func(txn storepkg.Transaction) error { + vms, err := txn.ListVMs() + if err != nil { + return err + } + + for _, vm := range vms { + updated := false + + if vm.OS == "" { + vm.OS = v1.OSDarwin + updated = true + } + if vm.Arch == "" { + vm.Arch = v1.ArchitectureARM64 + updated = true + } + if vm.Runtime == "" { + vm.Runtime = v1.RuntimeTart + updated = true + } + + if !updated { + continue + } + + if err := txn.SetVM(vm); err != nil { + return err + } + } + + return nil + }) +} + +func (controller *Controller) workersEnsurePlatformDefaults() error { + return controller.store.Update(func(txn storepkg.Transaction) error { + workers, err := txn.ListWorkers() + if err != nil { + return err + } + + for _, worker := range workers { + updated := false + + if worker.Arch == "" { + worker.Arch = v1.ArchitectureARM64 + updated = true + } + if worker.Runtime == "" { + worker.Runtime = v1.RuntimeTart + updated = true + } + + if !updated { + continue + } + + if err := txn.SetWorker(worker); err != nil { + return err + } + } + + return nil + }) +} + func (controller *Controller) ServiceAccounts() ([]v1.ServiceAccount, error) { var serviceAccounts []v1.ServiceAccount var err error diff --git a/internal/controller/scheduler/scheduler.go b/internal/controller/scheduler/scheduler.go index f6527f8..e45c06a 100644 --- a/internal/controller/scheduler/scheduler.go +++ b/internal/controller/scheduler/scheduler.go @@ -273,6 +273,7 @@ NextVM: if worker.Offline(scheduler.workerOfflineTimeout) || worker.SchedulingPaused || + !compatibleArchAndRuntime(unscheduledVM, worker) || !resourcesRemaining.CanFit(unscheduledVM.Resources) || !worker.Labels.Contains(unscheduledVM.Labels) { continue NextWorker @@ -328,6 +329,10 @@ NextVM: return ErrWorkerSchedulingSkipped } + if !compatibleArchAndRuntime(unscheduledVM, *currentWorker) { + return ErrWorkerSchedulingSkipped + } + if currentWorker.MachineID != worker.MachineID || !currentWorker.Resources.Equal(worker.Resources) { // Worker has changed @@ -433,6 +438,10 @@ func ProcessVMs(vms []v1.VM) ([]v1.VM, WorkerInfos) { return unscheduledVMs, workerToResources } +func compatibleArchAndRuntime(vm v1.VM, worker v1.Worker) bool { + return vm.Arch == worker.Arch && vm.Runtime == worker.Runtime +} + func (scheduler *Scheduler) healthCheckingLoopIteration() (int, error) { // Stats for the caller var numVMs int @@ -515,7 +524,9 @@ func (scheduler *Scheduler) healthCheckVM(txn storepkg.Transaction, vm v1.VM) er vm.ScheduledAt = time.Time{} vm.StartedAt = time.Time{} vm.PowerState = v1.PowerStateRunning - vm.TartName = ondiskname.New(vm.Name, vm.UID, vm.RestartCount).String() + vm.LocalName = ondiskname.New(vm.Name, vm.UID, vm.RestartCount).String() + //nolint:staticcheck // yes, this is deprecated, but we still maintain it for backward compatibility + vm.TartName = vm.LocalName vm.Conditions = []v1.Condition{ { Type: v1.ConditionTypeScheduled, diff --git a/internal/imageconstant/imageconstant.go b/internal/imageconstant/imageconstant.go index 325b529..1a51d1b 100644 --- a/internal/imageconstant/imageconstant.go +++ b/internal/imageconstant/imageconstant.go @@ -1,3 +1,6 @@ package imageconstant -const DefaultMacosImage = "ghcr.io/cirruslabs/macos-tahoe-base:latest" +const ( + DefaultMacosImage = "ghcr.io/cirruslabs/macos-tahoe-base:latest" + DefaultLinuxImage = "ghcr.io/cirruslabs/ubuntu-runner-amd64:latest" +) diff --git a/internal/tests/events_pagination_test.go b/internal/tests/events_pagination_test.go index 46cd8c4..6c421f0 100644 --- a/internal/tests/events_pagination_test.go +++ b/internal/tests/events_pagination_test.go @@ -8,8 +8,8 @@ import ( "net/url" "testing" - "github.com/cirruslabs/orchard/internal/imageconstant" "github.com/cirruslabs/orchard/internal/tests/devcontroller" + "github.com/cirruslabs/orchard/internal/tests/platformdependent" "github.com/cirruslabs/orchard/pkg/client" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/stretchr/testify/require" @@ -22,14 +22,8 @@ func TestListVMEventsPagination(t *testing.T) { ) ctx := context.Background() - vm := v1.VM{ - Meta: v1.Meta{Name: "test-vm"}, - Image: imageconstant.DefaultMacosImage, - CPU: 1, - Memory: 1024, - Headless: true, - } - require.NoError(t, devClient.VMs().Create(ctx, &vm)) + vm := platformdependent.VM("test-vm") + require.NoError(t, devClient.VMs().Create(ctx, vm)) events := []v1.Event{ {Kind: v1.EventKindLogLine, Timestamp: 1, Payload: "one"}, diff --git a/internal/tests/exec_test.go b/internal/tests/exec_test.go index b1ffa36..227f66d 100644 --- a/internal/tests/exec_test.go +++ b/internal/tests/exec_test.go @@ -7,8 +7,8 @@ import ( "time" "github.com/cirruslabs/orchard/internal/execstream" - "github.com/cirruslabs/orchard/internal/imageconstant" "github.com/cirruslabs/orchard/internal/tests/devcontroller" + "github.com/cirruslabs/orchard/internal/tests/platformdependent" "github.com/cirruslabs/orchard/internal/tests/wait" "github.com/cirruslabs/orchard/pkg/client" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" @@ -137,13 +137,7 @@ func prepareForExec(t *testing.T) (*client.Client, string) { vmName := "test-vm-exec-" + uuid.NewString() - err := devClient.VMs().Create(t.Context(), &v1.VM{ - Meta: v1.Meta{ - Name: vmName, - }, - Image: imageconstant.DefaultMacosImage, - Headless: true, - }) + err := devClient.VMs().Create(t.Context(), platformdependent.VM(vmName)) require.NoError(t, err) require.True(t, wait.Wait(2*time.Minute, func() bool { diff --git a/internal/tests/integration_test.go b/internal/tests/integration_test.go index 0a72ef8..a16d9de 100644 --- a/internal/tests/integration_test.go +++ b/internal/tests/integration_test.go @@ -6,6 +6,7 @@ import ( "net" "os" "path/filepath" + "runtime" "strconv" "strings" "testing" @@ -14,9 +15,10 @@ import ( "github.com/cirruslabs/orchard/internal/controller" "github.com/cirruslabs/orchard/internal/imageconstant" "github.com/cirruslabs/orchard/internal/tests/devcontroller" + "github.com/cirruslabs/orchard/internal/tests/platformdependent" "github.com/cirruslabs/orchard/internal/tests/wait" "github.com/cirruslabs/orchard/internal/worker/ondiskname" - "github.com/cirruslabs/orchard/internal/worker/vmmanager/tart" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -24,6 +26,8 @@ import ( "go.uber.org/zap" "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) func TestSingleVM(t *testing.T) { @@ -34,20 +38,14 @@ func TestSingleVM(t *testing.T) { t.Fatal(err) } assert.Equal(t, 1, len(workers)) - err = devClient.VMs().Create(context.Background(), &v1.VM{ - Meta: v1.Meta{ - Name: "test-vm", - }, - Image: imageconstant.DefaultMacosImage, - CPU: 4, - Memory: 8 * 1024, - Headless: true, - Status: v1.VMStatusPending, - StartupScript: &v1.VMScript{ - ScriptContent: "echo \"Hello, $FOO!\"\nfor i in $(seq 1 1000); do echo \"$i\"; done", - Env: map[string]string{"FOO": "Bar"}, - }, - }) + + vm := platformdependent.VM("test-vm") + vm.StartupScript = &v1.VMScript{ + ScriptContent: "echo \"Hello, $FOO!\"\nfor i in $(seq 1 1000); do echo \"$i\"; done", + Env: map[string]string{"FOO": "Bar"}, + } + + err = devClient.VMs().Create(context.Background(), vm) if err != nil { t.Fatal(err) } @@ -82,7 +80,7 @@ func TestSingleVM(t *testing.T) { assert.Contains(t, strings.Join(logLines, "\n"), strings.Join(expectedLogs, "\n")) // Ensure that the VM exists on disk before deleting it - require.True(t, hasVMByPredicate(t, func(info tart.VMInfo) bool { + require.True(t, hasVMByPredicate(t, func(info vmmanager.VMInfo) bool { return strings.Contains(info.Name, runningVM.UID) }, nil)) @@ -93,7 +91,7 @@ func TestSingleVM(t *testing.T) { assert.True(t, wait.Wait(2*time.Minute, func() bool { t.Logf("Waiting for the VM to be garbage collected...") - return !hasVMByPredicate(t, func(info tart.VMInfo) bool { + return !hasVMByPredicate(t, func(info vmmanager.VMInfo) bool { return strings.Contains(info.Name, runningVM.UID) }, nil) }), "VM was not garbage collected in a timely manner") @@ -107,19 +105,13 @@ func TestFailedStartupScript(t *testing.T) { t.Fatal(err) } assert.Equal(t, 1, len(workers)) - err = devClient.VMs().Create(context.Background(), &v1.VM{ - Meta: v1.Meta{ - Name: "test-vm", - }, - Image: imageconstant.DefaultMacosImage, - CPU: 4, - Memory: 8 * 1024, - Headless: true, - Status: v1.VMStatusPending, - StartupScript: &v1.VMScript{ - ScriptContent: "exit 123", - }, - }) + + vm := platformdependent.VM("test-vm") + vm.StartupScript = &v1.VMScript{ + ScriptContent: "set +e && exit 123", + } + + err = devClient.VMs().Create(context.Background(), vm) if err != nil { t.Fatal(err) } @@ -145,15 +137,7 @@ func TestPortForwarding(t *testing.T) { devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) // Create a generic macOS VM - err := devClient.VMs().Create(ctx, &v1.VM{ - Meta: v1.Meta{ - Name: "test-vm", - }, - Image: imageconstant.DefaultMacosImage, - CPU: 4, - Memory: 8 * 1024, - Headless: true, - }) + err := devClient.VMs().Create(ctx, platformdependent.VM("test-vm")) require.NoError(t, err) // Establish port forwarding to VMs SSH port @@ -185,9 +169,9 @@ func TestPortForwarding(t *testing.T) { sshSession, err := sshClient.NewSession() require.NoError(t, err) - unameOutput, err := sshSession.Output("uname -mo") + unameOutput, err := sshSession.Output("uname -a") require.NoError(t, err) - require.Contains(t, string(unameOutput), "Darwin arm64") + require.Contains(t, string(unameOutput), cases.Title(language.English).String(runtime.GOOS)) } // TestSchedulerHealthCheckingNonExistentWorker ensures that scheduler @@ -203,6 +187,12 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) { dummyVMName = "dummy-vm" ) + // Prepare a dummy VM + vm := platformdependent.VM(dummyVMName) + vm.Resources = map[string]uint64{ + "unique-resource": 1, + } + // Create a dummy worker that won't update it's LastSeen // timestamp, which will result in scheduler failing VMs // scheduled on that worker. @@ -220,22 +210,13 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) { v1.ResourceTartVMs: 1, "unique-resource": 1, }, + Arch: vm.Arch, + Runtime: vm.Runtime, }) require.NoError(t, err) // Create a dummy VM - err = devClient.VMs().Create(context.Background(), &v1.VM{ - Meta: v1.Meta{ - Name: dummyVMName, - }, - Image: imageconstant.DefaultMacosImage, - CPU: 4, - Memory: 8 * 1024, - Headless: true, - Resources: map[string]uint64{ - "unique-resource": 1, - }, - }) + err = devClient.VMs().Create(context.Background(), vm) require.NoError(t, err) // Wait for the dummy VM to get scheduled to a dummy worker @@ -263,7 +244,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) { }), "VM was not marked as failed in time") // Double check VM's status and status message - vm, err := devClient.VMs().Get(context.Background(), dummyVMName) + vm, err = devClient.VMs().Get(context.Background(), dummyVMName) require.NoError(t, err) require.Equal(t, v1.VMStatusFailed, vm.Status) require.Equal(t, "VM is assigned to a worker that doesn't exist anymore", vm.StatusMessage) @@ -285,6 +266,12 @@ func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) { dummyVMName = "dummy-vm" ) + // Prepare a dummy VM that will be assigned to our dummy worker + vm := platformdependent.VM(dummyVMName) + vm.Resources = map[string]uint64{ + "unique-resource": 1, + } + // Create a dummy worker that will be eventually marked as offline // because we won't update the LastSeen field _, err := devClient.Workers().Create(ctx, v1.Worker{ @@ -297,22 +284,13 @@ func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) { v1.ResourceTartVMs: 1, "unique-resource": 1, }, + Arch: vm.Arch, + Runtime: vm.Runtime, }) require.NoError(t, err) - // Create a dummy VM that will be assigned to our dummy worker - err = devClient.VMs().Create(context.Background(), &v1.VM{ - Meta: v1.Meta{ - Name: dummyVMName, - }, - Image: imageconstant.DefaultMacosImage, - CPU: 4, - Memory: 8 * 1024, - Headless: true, - Resources: map[string]uint64{ - "unique-resource": 1, - }, - }) + // Create a dummy VM + err = devClient.VMs().Create(context.Background(), vm) require.NoError(t, err) // Wait for the VM to be marked as failed @@ -337,15 +315,12 @@ func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) { // and are not present in the API anymore are garbage-collected by the Orchard Worker // at startup. func TestVMGarbageCollection(t *testing.T) { - ctx := context.Background() - logger, err := zap.NewDevelopment() require.NoError(t, err) // Create on-disk Tart VM that looks like it's managed by Orchard vmName := ondiskname.New("test", uuid.New().String(), 0).String() - _, _, err = tart.Tart(ctx, logger.Sugar(), "clone", - imageconstant.DefaultMacosImage, vmName) + err = platformdependent.CloneDefaultImage(t.Context(), logger.Sugar(), vmName) require.NoError(t, err) // Make sure that this VM exists @@ -363,6 +338,10 @@ func TestVMGarbageCollection(t *testing.T) { } func TestHostDirs(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("HostDirs is only supported on macOS with Tart") + } + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) dirToMount := t.TempDir() @@ -431,6 +410,10 @@ func TestHostDirs(t *testing.T) { } func TestHostDirsInvalidPolicy(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("HostDirs is only supported on macOS with Tart") + } + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) dirToMount := t.TempDir() @@ -468,17 +451,17 @@ func TestHostDirsInvalidPolicy(t *testing.T) { } func hasVM(t *testing.T, name string, logger *zap.Logger) bool { - return hasVMByPredicate(t, func(vmInfo tart.VMInfo) bool { + return hasVMByPredicate(t, func(vmInfo vmmanager.VMInfo) bool { return vmInfo.Name == name }, logger) } -func hasVMByPredicate(t *testing.T, predicate func(tart.VMInfo) bool, logger *zap.Logger) bool { +func hasVMByPredicate(t *testing.T, predicate func(vmmanager.VMInfo) bool, logger *zap.Logger) bool { if logger == nil { logger = zap.Must(zap.NewDevelopment()) } - vmInfos, err := tart.List(context.Background(), logger.Sugar()) + vmInfos, err := platformdependent.ListVMs(context.Background(), logger.Sugar()) require.NoError(t, err) return slices.ContainsFunc(vmInfos, predicate) diff --git a/internal/tests/ip_endpoint_test.go b/internal/tests/ip_endpoint_test.go index f3f27a7..ef6248b 100644 --- a/internal/tests/ip_endpoint_test.go +++ b/internal/tests/ip_endpoint_test.go @@ -3,15 +3,18 @@ package tests_test import ( "context" "net" + "runtime" "testing" "time" - "github.com/cirruslabs/orchard/internal/imageconstant" "github.com/cirruslabs/orchard/internal/tests/devcontroller" + "github.com/cirruslabs/orchard/internal/tests/platformdependent" "github.com/cirruslabs/orchard/internal/tests/wait" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) func TestIPEndpoint(t *testing.T) { @@ -19,16 +22,7 @@ func TestIPEndpoint(t *testing.T) { devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) // Create a VM to which we'll connect via Controller's SSH server - err := devClient.VMs().Create(context.Background(), &v1.VM{ - Meta: v1.Meta{ - Name: "test-vm", - }, - Image: imageconstant.DefaultMacosImage, - CPU: 4, - Memory: 8 * 1024, - Headless: true, - Status: v1.VMStatusPending, - }) + err := devClient.VMs().Create(context.Background(), platformdependent.VM("test-vm")) require.NoError(t, err) // Wait for the VM to start @@ -60,5 +54,5 @@ func TestIPEndpoint(t *testing.T) { output, err := sshSession.CombinedOutput("uname -a") require.NoError(t, err) - require.Contains(t, string(output), "Darwin") + require.Contains(t, string(output), cases.Title(language.English).String(runtime.GOOS)) } diff --git a/internal/tests/platformdependent/platformdependent.go b/internal/tests/platformdependent/platformdependent.go new file mode 100644 index 0000000..d9fb4d2 --- /dev/null +++ b/internal/tests/platformdependent/platformdependent.go @@ -0,0 +1,59 @@ +package platformdependent + +import ( + "context" + "runtime" + + "github.com/cirruslabs/orchard/internal/imageconstant" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + "github.com/cirruslabs/orchard/internal/worker/vmmanager/tart" + "github.com/cirruslabs/orchard/internal/worker/vmmanager/vetu" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "go.uber.org/zap" +) + +func VM(name string) *v1.VM { + vm := &v1.VM{ + Meta: v1.Meta{ + Name: name, + }, + Image: imageconstant.DefaultMacosImage, + CPU: 4, + Memory: 8 * 1024, + Headless: true, + } + + if runtime.GOOS == "linux" { + vm.Image = imageconstant.DefaultLinuxImage + vm.OS = v1.OSLinux + vm.Arch = v1.ArchitectureAMD64 + vm.Runtime = v1.RuntimeVetu + } + + return vm +} + +func CloneDefaultImage(ctx context.Context, logger *zap.SugaredLogger, destination string) error { + var err error + + if runtime.GOOS == "linux" { + _, _, err = vetu.Vetu(ctx, logger, "clone", imageconstant.DefaultLinuxImage, destination) + } else { + _, _, err = tart.Tart(ctx, logger, "clone", imageconstant.DefaultMacosImage, destination) + } + + return err +} + +func ListVMs(ctx context.Context, logger *zap.SugaredLogger) ([]vmmanager.VMInfo, error) { + var vms []vmmanager.VMInfo + var err error + + if runtime.GOOS == "linux" { + vms, err = vetu.List(ctx, logger) + } else { + vms, err = tart.List(ctx, logger) + } + + return vms, err +} diff --git a/internal/tests/spec_update_test.go b/internal/tests/spec_update_test.go index 5deafac..9a61c50 100644 --- a/internal/tests/spec_update_test.go +++ b/internal/tests/spec_update_test.go @@ -3,6 +3,7 @@ package tests import ( "context" "fmt" + "runtime" "testing" "time" @@ -10,6 +11,7 @@ import ( "github.com/cirruslabs/orchard/internal/tests/devcontroller" "github.com/cirruslabs/orchard/internal/tests/wait" "github.com/cirruslabs/orchard/internal/worker/ondiskname" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" "github.com/cirruslabs/orchard/internal/worker/vmmanager/tart" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/samber/lo" @@ -19,6 +21,10 @@ import ( ) func TestSpecUpdateSoftnet(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("Softnet is only supported on macOS with Tart") + } + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) // Create a VM @@ -82,6 +88,10 @@ func TestSpecUpdateSoftnet(t *testing.T) { } func TestSpecUpdateSoftnetSuspendable(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("Softnet is only supported on macOS with Tart") + } + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) // Create a suspendable VM with Softnet enabled @@ -150,6 +160,10 @@ func TestSpecUpdateSoftnetSuspendable(t *testing.T) { } func TestSpecUpdatePowerStateSuspend(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("VM suspension is only supported on macOS with Tart") + } + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) // Create a suspendable VM with Softnet enabled @@ -212,8 +226,8 @@ func TestSpecUpdatePowerStateSuspend(t *testing.T) { // Ensure that the VM is present and is suspended tartVMs, err := tart.List(t.Context(), zap.NewNop().Sugar()) require.NoError(t, err) - require.Contains(t, tartVMs, tart.VMInfo{ - Name: vm.TartName, + require.Contains(t, tartVMs, vmmanager.VMInfo{ + Name: vm.LocalName, Source: "local", State: "suspended", Running: false, @@ -221,6 +235,10 @@ func TestSpecUpdatePowerStateSuspend(t *testing.T) { } func TestSpecUpdatePowerStateStopped(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("VM suspension and Softnet is only supported on macOS with Tart") + } + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) // Create a suspendable VM with Softnet enabled @@ -283,8 +301,8 @@ func TestSpecUpdatePowerStateStopped(t *testing.T) { // Ensure that the VM is present and is suspended tartVMs, err := tart.List(t.Context(), zap.NewNop().Sugar()) require.NoError(t, err) - require.Contains(t, tartVMs, tart.VMInfo{ - Name: vm.TartName, + require.Contains(t, tartVMs, vmmanager.VMInfo{ + Name: vm.LocalName, Source: "local", State: "stopped", Running: false, diff --git a/internal/tests/sshserver_test.go b/internal/tests/sshserver_test.go index 234da6e..2becffc 100644 --- a/internal/tests/sshserver_test.go +++ b/internal/tests/sshserver_test.go @@ -5,18 +5,21 @@ import ( "crypto/subtle" "fmt" "net" + "runtime" "testing" "time" "github.com/cirruslabs/orchard/internal/controller" - "github.com/cirruslabs/orchard/internal/imageconstant" "github.com/cirruslabs/orchard/internal/tests/devcontroller" + "github.com/cirruslabs/orchard/internal/tests/platformdependent" "github.com/cirruslabs/orchard/internal/tests/wait" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" "golang.org/x/crypto/ssh" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) func TestSSHServer(t *testing.T) { @@ -39,16 +42,7 @@ func TestSSHServer(t *testing.T) { ) // Create a VM to which we'll connect via Controller's SSH server - err = devClient.VMs().Create(context.Background(), &v1.VM{ - Meta: v1.Meta{ - Name: "test-vm", - }, - Image: imageconstant.DefaultMacosImage, - CPU: 4, - Memory: 8 * 1024, - Headless: true, - Status: v1.VMStatusPending, - }) + err = devClient.VMs().Create(context.Background(), platformdependent.VM("test-vm")) require.NoError(t, err) // Wait for the VM to start @@ -110,5 +104,5 @@ func TestSSHServer(t *testing.T) { unameBytes, err := sshSessVM.Output("uname -a") require.NoError(t, err) - require.Contains(t, string(unameBytes), "Darwin") + require.Contains(t, string(unameBytes), cases.Title(language.English).String(runtime.GOOS)) } diff --git a/internal/worker/option.go b/internal/worker/option.go index 50378bf..f419bdd 100644 --- a/internal/worker/option.go +++ b/internal/worker/option.go @@ -2,6 +2,7 @@ package worker import ( "github.com/cirruslabs/orchard/internal/dialer" + "github.com/cirruslabs/orchard/internal/worker/runtime" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "go.uber.org/zap" ) @@ -47,7 +48,7 @@ func WithDialer(dialer dialer.Dialer) Option { func WithSynthetic() Option { return func(worker *Worker) { - worker.synthetic = true + worker.runtime = runtime.NewSynthetic() } } diff --git a/internal/worker/platform/iokitregistry/iokitregistry_test.go b/internal/worker/platform/iokitregistry/iokitregistry_test.go index 4397926..a5c6213 100644 --- a/internal/worker/platform/iokitregistry/iokitregistry_test.go +++ b/internal/worker/platform/iokitregistry/iokitregistry_test.go @@ -1,3 +1,5 @@ +//go:build darwin + package iokitregistry_test import ( diff --git a/internal/worker/runtime/runtime.go b/internal/worker/runtime/runtime.go new file mode 100644 index 0000000..77acd6b --- /dev/null +++ b/internal/worker/runtime/runtime.go @@ -0,0 +1,26 @@ +package runtime + +import ( + "context" + + "github.com/cirruslabs/orchard/internal/dialer" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + "github.com/cirruslabs/orchard/pkg/client" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "go.opentelemetry.io/otel/metric" + "go.uber.org/zap" +) + +type Runtime interface { + ID() v1.Runtime + Synthetic() bool + NewVM( + vmResource v1.VM, + eventStreamer *client.EventStreamer, + vmPullTimeHistogram metric.Float64Histogram, + dialer dialer.Dialer, + logger *zap.SugaredLogger, + ) vmmanager.VM + ListVMs(ctx context.Context, logger *zap.SugaredLogger) ([]vmmanager.VMInfo, error) + Cmd(ctx context.Context, logger *zap.SugaredLogger, args ...string) (string, string, error) +} diff --git a/internal/worker/runtime/synthetic.go b/internal/worker/runtime/synthetic.go new file mode 100644 index 0000000..94584f6 --- /dev/null +++ b/internal/worker/runtime/synthetic.go @@ -0,0 +1,51 @@ +package runtime + +import ( + "context" + "runtime" + + "github.com/cirruslabs/orchard/internal/dialer" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + syntheticpkg "github.com/cirruslabs/orchard/internal/worker/vmmanager/synthetic" + "github.com/cirruslabs/orchard/pkg/client" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "go.opentelemetry.io/otel/metric" + "go.uber.org/zap" +) + +type Synthetic struct{} + +func NewSynthetic() *Synthetic { + return &Synthetic{} +} + +func (synthetic *Synthetic) ID() v1.Runtime { + // Fake runtime depending on the OS + if runtime.GOOS == "linux" { + return v1.RuntimeVetu + } else { + return v1.RuntimeTart + } +} + +func (synthetic *Synthetic) Synthetic() bool { + return true +} + +func (synthetic *Synthetic) NewVM( + vmResource v1.VM, + eventStreamer *client.EventStreamer, + vmPullTimeHistogram metric.Float64Histogram, + _ dialer.Dialer, + logger *zap.SugaredLogger, +) vmmanager.VM { + return syntheticpkg.NewVM(vmResource, eventStreamer, vmPullTimeHistogram, logger) +} + +func (synthetic *Synthetic) ListVMs(ctx context.Context, logger *zap.SugaredLogger) ([]vmmanager.VMInfo, error) { + return nil, nil +} + +func (synthetic *Synthetic) Cmd(_ context.Context, _ *zap.SugaredLogger, _ ...string) (string, string, error) { + return "", "", nil +} diff --git a/internal/worker/runtime/tart.go b/internal/worker/runtime/tart.go new file mode 100644 index 0000000..f152206 --- /dev/null +++ b/internal/worker/runtime/tart.go @@ -0,0 +1,45 @@ +package runtime + +import ( + "context" + + "github.com/cirruslabs/orchard/internal/dialer" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + tartpkg "github.com/cirruslabs/orchard/internal/worker/vmmanager/tart" + "github.com/cirruslabs/orchard/pkg/client" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "go.opentelemetry.io/otel/metric" + "go.uber.org/zap" +) + +type Tart struct{} + +func NewTart() *Tart { + return &Tart{} +} + +func (tart *Tart) ID() v1.Runtime { + return v1.RuntimeTart +} + +func (tart *Tart) Synthetic() bool { + return false +} + +func (tart *Tart) NewVM( + vmResource v1.VM, + eventStreamer *client.EventStreamer, + vmPullTimeHistogram metric.Float64Histogram, + dialer dialer.Dialer, + logger *zap.SugaredLogger, +) vmmanager.VM { + return tartpkg.NewVM(vmResource, eventStreamer, vmPullTimeHistogram, dialer, logger) +} + +func (tart *Tart) ListVMs(ctx context.Context, logger *zap.SugaredLogger) ([]vmmanager.VMInfo, error) { + return tartpkg.List(ctx, logger) +} + +func (tart *Tart) Cmd(ctx context.Context, logger *zap.SugaredLogger, args ...string) (string, string, error) { + return tartpkg.Tart(ctx, logger, args...) +} diff --git a/internal/worker/runtime/vetu.go b/internal/worker/runtime/vetu.go new file mode 100644 index 0000000..84884d7 --- /dev/null +++ b/internal/worker/runtime/vetu.go @@ -0,0 +1,45 @@ +package runtime + +import ( + "context" + + "github.com/cirruslabs/orchard/internal/dialer" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + vetupkg "github.com/cirruslabs/orchard/internal/worker/vmmanager/vetu" + "github.com/cirruslabs/orchard/pkg/client" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "go.opentelemetry.io/otel/metric" + "go.uber.org/zap" +) + +type Vetu struct{} + +func NewVetu() *Vetu { + return &Vetu{} +} + +func (vetu *Vetu) ID() v1.Runtime { + return v1.RuntimeVetu +} + +func (vetu *Vetu) Synthetic() bool { + return false +} + +func (vetu *Vetu) NewVM( + vmResource v1.VM, + eventStreamer *client.EventStreamer, + vmPullTimeHistogram metric.Float64Histogram, + dialer dialer.Dialer, + logger *zap.SugaredLogger, +) vmmanager.VM { + return vetupkg.NewVM(vmResource, eventStreamer, vmPullTimeHistogram, dialer, logger) +} + +func (vetu *Vetu) ListVMs(ctx context.Context, logger *zap.SugaredLogger) ([]vmmanager.VMInfo, error) { + return vetupkg.List(ctx, logger) +} + +func (vetu *Vetu) Cmd(ctx context.Context, logger *zap.SugaredLogger, args ...string) (string, string, error) { + return vetupkg.Vetu(ctx, logger, args...) +} diff --git a/internal/worker/vmmanager/base/base.go b/internal/worker/vmmanager/base/base.go index 5b8d2ce..46f4f89 100644 --- a/internal/worker/vmmanager/base/base.go +++ b/internal/worker/vmmanager/base/base.go @@ -1,13 +1,28 @@ package base import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "net" + "strings" + "sync" "sync/atomic" + "time" + "github.com/avast/retry-go/v4" + "github.com/cirruslabs/orchard/internal/dialer" + "github.com/cirruslabs/orchard/pkg/client" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" mapset "github.com/deckarep/golang-set/v2" "go.uber.org/zap" + "golang.org/x/crypto/ssh" ) +var ErrVMFailed = errors.New("VM failed") + type VM struct { // Backward compatibility with v1.VM specification's "Status" field // @@ -109,3 +124,161 @@ func (vm *VM) conditionTypeToCondition(conditionType v1.ConditionType) v1.Condit State: conditionState, } } + +func (vm *VM) Shell( + ctx context.Context, + sshUser string, + sshPassword string, + script string, + env map[string]string, + consumeLine func(line string), + dialer dialer.Dialer, + getIP func(ctx context.Context) (string, error), +) error { + var sess *ssh.Session + + // Set default user and password if not provided + if sshUser == "" && sshPassword == "" { + sshUser = "admin" + sshPassword = "admin" + } + + // Configure SSH client + sshConfig := &ssh.ClientConfig{ + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + User: sshUser, + Auth: []ssh.AuthMethod{ + ssh.Password(sshPassword), + }, + } + + if err := retry.Do(func() error { + ip, err := getIP(ctx) + if err != nil { + return fmt.Errorf("failed to get VM's IP: %w", err) + } + + addr := ip + ":22" + + dialCtx, dialCtxCancel := context.WithTimeout(ctx, 5*time.Second) + defer dialCtxCancel() + + var netConn net.Conn + + if dialer != nil { + netConn, err = dialer.DialContext(dialCtx, "tcp", addr) + } else { + dialer := net.Dialer{} + + netConn, err = dialer.DialContext(dialCtx, "tcp", addr) + } + if err != nil { + return fmt.Errorf("failed to dial %s: %w", addr, err) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(netConn, addr, sshConfig) + if err != nil { + return fmt.Errorf("SSH handshake with %s failed: %w", addr, err) + } + + sshClient := ssh.NewClient(sshConn, chans, reqs) + + sess, err = sshClient.NewSession() + if err != nil { + return fmt.Errorf("failed to open an SSH session on %s: %w", addr, err) + } + + return nil + }, retry.Context(ctx), retry.OnRetry(func(n uint, err error) { + consumeLine(fmt.Sprintf("attempt %d to establish SSH connection failed: %v", n, err)) + })); err != nil { + return fmt.Errorf("failed to establish SSH connection: %w", err) + } + + // Log output from the virtual machine + stdout, err := sess.StdoutPipe() + if err != nil { + return fmt.Errorf("%w: while opening stdout pipe: %v", ErrVMFailed, err) + } + stderr, err := sess.StderrPipe() + if err != nil { + return fmt.Errorf("%w: while opening stderr pipe: %v", ErrVMFailed, err) + } + var outputReaderWG sync.WaitGroup + outputReaderWG.Add(1) + go func() { + output := io.MultiReader(stdout, stderr) + + scanner := bufio.NewScanner(output) + + for scanner.Scan() { + consumeLine(scanner.Text()) + } + outputReaderWG.Done() + }() + + stdinBuf, err := sess.StdinPipe() + if err != nil { + return fmt.Errorf("%w: while opening stdin pipe: %v", ErrVMFailed, err) + } + + // start a login shell so all the customization from ~/.zprofile will be picked up + err = sess.Shell() + if err != nil { + return fmt.Errorf("%w: failed to start a shell: %v", ErrVMFailed, err) + } + + var scriptBuilder strings.Builder + + scriptBuilder.WriteString("set -e\n") + // don't use sess.Setenv since it requires non-default SSH server configuration + for key, value := range env { + scriptBuilder.WriteString("export " + key + "=\"" + value + "\"\n") + } + scriptBuilder.WriteString(script) + scriptBuilder.WriteString("\nexit\n") + + _, err = stdinBuf.Write([]byte(scriptBuilder.String())) + if err != nil { + return fmt.Errorf("%w: failed to start script: %v", ErrVMFailed, err) + } + outputReaderWG.Wait() + return sess.Wait() +} + +func (vm *VM) RunScript( + ctx context.Context, + sshUser string, + sshPassword string, + script *v1.VMScript, + eventStreamer *client.EventStreamer, + dialer dialer.Dialer, + getIP func(ctx context.Context) (string, error), +) { + if eventStreamer != nil { + defer func() { + if err := eventStreamer.Close(); err != nil { + vm.logger.Errorf("errored during streaming events for startup script: %v", err) + } + }() + } + + consumeLine := func(line string) { + if eventStreamer == nil { + return + } + + eventStreamer.Stream(v1.Event{ + Kind: v1.EventKindLogLine, + Timestamp: time.Now().Unix(), + Payload: line, + }) + } + + err := vm.Shell(ctx, sshUser, sshPassword, script.ScriptContent, script.Env, consumeLine, dialer, getIP) + if err != nil { + vm.SetErr(fmt.Errorf("%w: failed to run startup script: %v", ErrVMFailed, err)) + } +} diff --git a/internal/worker/vmmanager/base/cmd.go b/internal/worker/vmmanager/base/cmd.go new file mode 100644 index 0000000..7355232 --- /dev/null +++ b/internal/worker/vmmanager/base/cmd.go @@ -0,0 +1,83 @@ +package base + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "os/exec" + "strings" + + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + "go.uber.org/zap" +) + +func Cmd( + ctx context.Context, + logger *zap.SugaredLogger, + commandName string, + args ...string, +) (string, string, error) { + cmd := exec.CommandContext(ctx, commandName, args...) + + var stdout, stderr bytes.Buffer + + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + logger.Debugf("running '%s %s'", commandName, strings.Join(args, " ")) + err := cmd.Run() + if err != nil { + if errors.Is(err, exec.ErrNotFound) { + return "", "", fmt.Errorf("%s command not found in PATH, make sure %s is installed", + commandName, strings.ToTitle(commandName)) + } + + if exitErr, ok := err.(*exec.ExitError); ok { + select { + case <-ctx.Done(): + // Do not log an error because it's the user's intent to cancel this VM operation + default: + logger.Warnf( + "'%s %s' failed with exit code %d: %s", + commandName, strings.Join(args, " "), + exitErr.ExitCode(), firstNonEmptyLine(stderr.String(), stdout.String()), + ) + } + + // Command failed, redefine the error to be the command-specific output + err = fmt.Errorf("%s command failed: %q", commandName, + firstNonEmptyLine(stderr.String(), stdout.String())) + } + } + + return stdout.String(), stderr.String(), err +} + +func List(ctx context.Context, logger *zap.SugaredLogger, commandName string) ([]vmmanager.VMInfo, error) { + output, _, err := Cmd(ctx, logger, commandName, "list", "--format", "json") + if err != nil { + return nil, err + } + + var entries []vmmanager.VMInfo + + if err := json.Unmarshal([]byte(output), &entries); err != nil { + return nil, err + } + + return entries, nil +} + +func firstNonEmptyLine(outputs ...string) string { + for _, output := range outputs { + for _, line := range strings.Split(output, "\n") { + if line != "" { + return line + } + } + } + + return "" +} diff --git a/internal/worker/vmmanager/tart/cmd.go b/internal/worker/vmmanager/tart/cmd.go index 08d67e2..87e5ca8 100644 --- a/internal/worker/vmmanager/tart/cmd.go +++ b/internal/worker/vmmanager/tart/cmd.go @@ -1,94 +1,19 @@ package tart import ( - "bytes" "context" - "encoding/json" - "errors" - "fmt" - "os/exec" - "strings" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + "github.com/cirruslabs/orchard/internal/worker/vmmanager/base" "go.uber.org/zap" ) const tartCommandName = "tart" -var ( - ErrTartNotFound = errors.New("tart command not found") - ErrTartFailed = errors.New("tart command returned non-zero exit code") -) - -type VMInfo struct { - Name string - Source string - State string - Running bool +func Tart(ctx context.Context, logger *zap.SugaredLogger, args ...string) (string, string, error) { + return base.Cmd(ctx, logger, tartCommandName, args...) } -func Tart( - ctx context.Context, - logger *zap.SugaredLogger, - args ...string, -) (string, string, error) { - cmd := exec.CommandContext(ctx, tartCommandName, args...) - - var stdout, stderr bytes.Buffer - - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - logger.Debugf("running '%s %s'", tartCommandName, strings.Join(args, " ")) - err := cmd.Run() - if err != nil { - if errors.Is(err, exec.ErrNotFound) { - return "", "", fmt.Errorf("%w: %s command not found in PATH, make sure Tart is installed", - ErrTartNotFound, tartCommandName) - } - - if exitErr, ok := err.(*exec.ExitError); ok { - select { - case <-ctx.Done(): - // Do not log an error because it's the user's intent to cancel this VM operation - default: - logger.Warnf( - "'%s %s' failed with exit code %d: %s", - tartCommandName, strings.Join(args, " "), - exitErr.ExitCode(), firstNonEmptyLine(stderr.String(), stdout.String()), - ) - } - - // Tart command failed, redefine the error to be the Tart-specific output - err = fmt.Errorf("%w: %q", ErrTartFailed, firstNonEmptyLine(stderr.String(), stdout.String())) - } - } - - return stdout.String(), stderr.String(), err -} - -func List(ctx context.Context, logger *zap.SugaredLogger) ([]VMInfo, error) { - output, _, err := Tart(ctx, logger, "list", "--format", "json") - if err != nil { - return nil, err - } - - var entries []VMInfo - - if err := json.Unmarshal([]byte(output), &entries); err != nil { - return nil, err - } - - return entries, nil -} - -func firstNonEmptyLine(outputs ...string) string { - for _, output := range outputs { - for _, line := range strings.Split(output, "\n") { - if line != "" { - return line - } - } - } - - return "" +func List(ctx context.Context, logger *zap.SugaredLogger) ([]vmmanager.VMInfo, error) { + return base.List(ctx, logger, tartCommandName) } diff --git a/internal/worker/vmmanager/tart/tart.go b/internal/worker/vmmanager/tart/tart.go index 36abaa8..b5f55e8 100644 --- a/internal/worker/vmmanager/tart/tart.go +++ b/internal/worker/vmmanager/tart/tart.go @@ -1,19 +1,14 @@ package tart import ( - "bufio" "context" - "errors" "fmt" - "io" - "net" "strconv" "strings" "sync" "sync/atomic" "time" - "github.com/avast/retry-go" "github.com/cirruslabs/orchard/internal/dialer" "github.com/cirruslabs/orchard/internal/worker/ondiskname" "github.com/cirruslabs/orchard/internal/worker/vmmanager/base" @@ -22,11 +17,8 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "go.uber.org/zap" - "golang.org/x/crypto/ssh" ) -var ErrVMFailed = errors.New("VM failed") - type VM struct { onDiskName ondiskname.OnDiskName resource v1.VM @@ -282,7 +274,8 @@ func (vm *VM) run(ctx context.Context, eventStreamer *client.EventStreamer) { if vm.resource.StartupScript != nil { vm.SetStatusMessage("VM started, running startup script...") - go vm.runScript(vm.resource.StartupScript, eventStreamer) + go vm.RunScript(vm.ctx, vm.resource.Username, vm.resource.Password, vm.resource.StartupScript, + eventStreamer, vm.dialer, vm.IP) } else { vm.SetStatusMessage("VM started") } @@ -325,7 +318,7 @@ func (vm *VM) run(ctx context.Context, eventStreamer *client.EventStreamer) { case <-vm.ctx.Done(): // Do not return an error because it's the user's intent to cancel this VM default: - vm.SetErr(fmt.Errorf("%w: %v", ErrVMFailed, err)) + vm.SetErr(fmt.Errorf("%w: %v", base.ErrVMFailed, err)) } return @@ -336,7 +329,7 @@ func (vm *VM) run(ctx context.Context, eventStreamer *client.EventStreamer) { // Do not return an error because it's the user's intent to cancel this VM default: if !vm.ConditionsSet().ContainsAny(v1.ConditionTypeSuspending, v1.ConditionTypeStopping) { - vm.SetErr(fmt.Errorf("%w: VM exited unexpectedly", ErrVMFailed)) + vm.SetErr(fmt.Errorf("%w: VM exited unexpectedly", base.ErrVMFailed)) } } } @@ -459,157 +452,8 @@ func (vm *VM) Delete() error { _, _, err := Tart(context.Background(), vm.logger, "delete", vm.id()) if err != nil { - return fmt.Errorf("%w: failed to delete VM: %v", ErrVMFailed, err) + return fmt.Errorf("%w: failed to delete VM: %v", base.ErrVMFailed, err) } return nil } - -func (vm *VM) shell( - ctx context.Context, - sshUser string, - sshPassword string, - script string, - env map[string]string, - consumeLine func(line string), -) error { - var sess *ssh.Session - - // Set default user and password if not provided - if sshUser == "" && sshPassword == "" { - sshUser = "admin" - sshPassword = "admin" - } - - // Configure SSH client - sshConfig := &ssh.ClientConfig{ - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - return nil - }, - User: sshUser, - Auth: []ssh.AuthMethod{ - ssh.Password(sshPassword), - }, - } - - if err := retry.Do(func() error { - ip, err := vm.IP(ctx) - if err != nil { - return fmt.Errorf("failed to get VM's IP: %w", err) - } - - addr := ip + ":22" - - dialCtx, dialCtxCancel := context.WithTimeout(ctx, 5*time.Second) - defer dialCtxCancel() - - var netConn net.Conn - - if vm.dialer != nil { - netConn, err = vm.dialer.DialContext(dialCtx, "tcp", addr) - } else { - dialer := net.Dialer{} - - netConn, err = dialer.DialContext(dialCtx, "tcp", addr) - } - if err != nil { - return fmt.Errorf("failed to dial %s: %w", addr, err) - } - - sshConn, chans, reqs, err := ssh.NewClientConn(netConn, addr, sshConfig) - if err != nil { - return fmt.Errorf("SSH handshake with %s failed: %w", addr, err) - } - - sshClient := ssh.NewClient(sshConn, chans, reqs) - - sess, err = sshClient.NewSession() - if err != nil { - return fmt.Errorf("failed to open an SSH session on %s: %w", addr, err) - } - - return nil - }, retry.Context(ctx), retry.OnRetry(func(n uint, err error) { - consumeLine(fmt.Sprintf("attempt %d to establish SSH connection failed: %v", n, err)) - })); err != nil { - return fmt.Errorf("failed to establish SSH connection: %w", err) - } - - // Log output from the virtual machine - stdout, err := sess.StdoutPipe() - if err != nil { - return fmt.Errorf("%w: while opening stdout pipe: %v", ErrVMFailed, err) - } - stderr, err := sess.StderrPipe() - if err != nil { - return fmt.Errorf("%w: while opening stderr pipe: %v", ErrVMFailed, err) - } - var outputReaderWG sync.WaitGroup - outputReaderWG.Add(1) - go func() { - output := io.MultiReader(stdout, stderr) - - scanner := bufio.NewScanner(output) - - for scanner.Scan() { - consumeLine(scanner.Text()) - } - outputReaderWG.Done() - }() - - stdinBuf, err := sess.StdinPipe() - if err != nil { - return fmt.Errorf("%w: while opening stdin pipe: %v", ErrVMFailed, err) - } - - // start a login shell so all the customization from ~/.zprofile will be picked up - err = sess.Shell() - if err != nil { - return fmt.Errorf("%w: failed to start a shell: %v", ErrVMFailed, err) - } - - var scriptBuilder strings.Builder - - scriptBuilder.WriteString("set -e\n") - // don't use sess.Setenv since it requires non-default SSH server configuration - for key, value := range env { - scriptBuilder.WriteString("export " + key + "=\"" + value + "\"\n") - } - scriptBuilder.WriteString(script) - scriptBuilder.WriteString("\nexit\n") - - _, err = stdinBuf.Write([]byte(scriptBuilder.String())) - if err != nil { - return fmt.Errorf("%w: failed to start script: %v", ErrVMFailed, err) - } - outputReaderWG.Wait() - return sess.Wait() -} - -func (vm *VM) runScript(script *v1.VMScript, eventStreamer *client.EventStreamer) { - if eventStreamer != nil { - defer func() { - if err := eventStreamer.Close(); err != nil { - vm.logger.Errorf("errored during streaming events for startup script: %v", err) - } - }() - } - - consumeLine := func(line string) { - if eventStreamer == nil { - return - } - - eventStreamer.Stream(v1.Event{ - Kind: v1.EventKindLogLine, - Timestamp: time.Now().Unix(), - Payload: line, - }) - } - - err := vm.shell(vm.ctx, vm.resource.Username, vm.resource.Password, - script.ScriptContent, script.Env, consumeLine) - if err != nil { - vm.SetErr(fmt.Errorf("%w: failed to run startup script: %v", ErrVMFailed, err)) - } -} diff --git a/internal/worker/vmmanager/vetu/cmd.go b/internal/worker/vmmanager/vetu/cmd.go new file mode 100644 index 0000000..319261d --- /dev/null +++ b/internal/worker/vmmanager/vetu/cmd.go @@ -0,0 +1,19 @@ +package vetu + +import ( + "context" + + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + "github.com/cirruslabs/orchard/internal/worker/vmmanager/base" + "go.uber.org/zap" +) + +const vetuCommandName = "vetu" + +func Vetu(ctx context.Context, logger *zap.SugaredLogger, args ...string) (string, string, error) { + return base.Cmd(ctx, logger, vetuCommandName, args...) +} + +func List(ctx context.Context, logger *zap.SugaredLogger) ([]vmmanager.VMInfo, error) { + return base.List(ctx, logger, vetuCommandName) +} diff --git a/internal/worker/vmmanager/vetu/vetu.go b/internal/worker/vmmanager/vetu/vetu.go new file mode 100644 index 0000000..f22e3e6 --- /dev/null +++ b/internal/worker/vmmanager/vetu/vetu.go @@ -0,0 +1,322 @@ +package vetu + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/cirruslabs/orchard/internal/dialer" + "github.com/cirruslabs/orchard/internal/worker/ondiskname" + "github.com/cirruslabs/orchard/internal/worker/vmmanager/base" + "github.com/cirruslabs/orchard/pkg/client" + "github.com/cirruslabs/orchard/pkg/resource/v1" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.uber.org/zap" +) + +type VM struct { + onDiskName ondiskname.OnDiskName + resource v1.VM + logger *zap.SugaredLogger + + // Image FQN feature, see https://github.com/cirruslabs/orchard/issues/164 + imageFQN atomic.Pointer[string] + + ctx context.Context + cancel context.CancelFunc + + wg *sync.WaitGroup + + dialer dialer.Dialer + + *base.VM +} + +func NewVM( + vmResource v1.VM, + eventStreamer *client.EventStreamer, + vmPullTimeHistogram metric.Float64Histogram, + dialer dialer.Dialer, + logger *zap.SugaredLogger, +) *VM { + vmContext, vmContextCancel := context.WithCancel(context.Background()) + + vm := &VM{ + onDiskName: ondiskname.NewFromResource(vmResource), + resource: vmResource, + logger: logger.With( + "vm_uid", vmResource.UID, + "vm_name", vmResource.Name, + "vm_restart_count", vmResource.RestartCount, + ), + + ctx: vmContext, + cancel: vmContextCancel, + + wg: &sync.WaitGroup{}, + + dialer: dialer, + + VM: base.NewVM(logger), + } + + vm.wg.Add(1) + + go func() { + defer vm.wg.Done() + + if vmResource.ImagePullPolicy == v1.ImagePullPolicyAlways { + vm.SetStatusMessage("pulling VM image...") + + pullStartedAt := time.Now() + + _, _, err := Vetu(vm.ctx, vm.logger, "pull", vm.resource.Image) + if err != nil { + select { + case <-vm.ctx.Done(): + // Do not return an error because it's the user's intent to cancel this VM operation + default: + vm.SetErr(fmt.Errorf("failed to pull the VM: %w", err)) + } + + return + } + + vmPullTimeHistogram.Record(vm.ctx, time.Since(pullStartedAt).Seconds(), metric.WithAttributes( + attribute.String("worker", vm.resource.Worker), + attribute.String("image", vm.resource.Image), + )) + } + + if err := vm.cloneAndConfigure(vm.ctx); err != nil { + select { + case <-vm.ctx.Done(): + // Do not return an error because it's the user's intent to cancel this VM operation + default: + vm.SetErr(fmt.Errorf("failed to clone the VM: %w", err)) + } + + return + } + + // Backward compatibility with v1.VM specification's "Status" field + vm.SetStarted(true) + + vm.ConditionsSet().Add(v1.ConditionTypeRunning) + + vm.run(vm.ctx, eventStreamer) + }() + + return vm +} + +func (vm *VM) Resource() v1.VM { + return vm.resource +} + +func (vm *VM) SetResource(vmResource v1.VM) { + vm.resource = vmResource + vm.resource.ObservedGeneration = vmResource.Generation +} + +func (vm *VM) OnDiskName() ondiskname.OnDiskName { + return vm.onDiskName +} + +func (vm *VM) ImageFQN() *string { + return vm.imageFQN.Load() +} + +func (vm *VM) id() string { + return vm.onDiskName.String() +} + +func (vm *VM) cloneAndConfigure(ctx context.Context) error { + vm.SetStatusMessage("cloning VM...") + + _, _, err := Vetu(ctx, vm.logger, "clone", vm.resource.Image, vm.id()) + if err != nil { + return err + } + + vm.ConditionsSet().Remove(v1.ConditionTypeCloning) + + // Image FQN feature, see https://github.com/cirruslabs/orchard/issues/164 + fqnRaw, _, err := Vetu(ctx, vm.logger, "fqn", vm.resource.Image) + if err == nil { + fqn := strings.TrimSpace(fqnRaw) + vm.imageFQN.Store(&fqn) + } + + // Set memory + vm.SetStatusMessage("configuring VM...") + + memory := vm.resource.AssignedMemory + + if memory == 0 { + memory = vm.resource.Memory + } + + if memory != 0 { + _, _, err = Vetu(ctx, vm.logger, "set", "--memory", + strconv.FormatUint(memory, 10), vm.id()) + if err != nil { + return err + } + } + + // Set CPU + cpu := vm.resource.AssignedCPU + + if cpu == 0 { + cpu = vm.resource.CPU + } + + if cpu != 0 { + _, _, err = Vetu(ctx, vm.logger, "set", "--cpu", + strconv.FormatUint(cpu, 10), vm.id()) + if err != nil { + return err + } + } + + if diskSize := vm.resource.DiskSize; diskSize != 0 { + _, _, err = Vetu(ctx, vm.logger, "set", "--disk-size", + strconv.FormatUint(diskSize, 10), vm.id()) + if err != nil { + return err + } + } + + return nil +} + +func (vm *VM) run(ctx context.Context, eventStreamer *client.EventStreamer) { + defer vm.ConditionsSet().RemoveAll(v1.ConditionTypeRunning, v1.ConditionTypeSuspending, v1.ConditionTypeStopping) + + // Launch the startup script goroutine as close as possible + // to the VM startup (below) to avoid "vetu ip" timing out + if vm.resource.StartupScript != nil { + vm.SetStatusMessage("VM started, running startup script...") + + go vm.RunScript(vm.ctx, vm.resource.Username, vm.resource.Password, vm.resource.StartupScript, + eventStreamer, vm.dialer, vm.IP) + } else { + vm.SetStatusMessage("VM started") + } + + var runArgs = []string{"run"} + + runArgs = append(runArgs, vm.id()) + _, _, err := Vetu(ctx, vm.logger, runArgs...) + if err != nil { + select { + case <-vm.ctx.Done(): + // Do not return an error because it's the user's intent to cancel this VM + default: + vm.SetErr(fmt.Errorf("%w: %v", base.ErrVMFailed, err)) + } + + return + } + + select { + case <-vm.ctx.Done(): + // Do not return an error because it's the user's intent to cancel this VM + default: + if !vm.ConditionsSet().ContainsAny(v1.ConditionTypeSuspending, v1.ConditionTypeStopping) { + vm.SetErr(fmt.Errorf("%w: VM exited unexpectedly", base.ErrVMFailed)) + } + } +} + +func (vm *VM) IP(ctx context.Context) (string, error) { + args := []string{"ip", "--wait", "60"} + + args = append(args, vm.id()) + + stdout, _, err := Vetu(ctx, vm.logger, args...) + if err != nil { + return "", err + } + + return strings.TrimSpace(stdout), nil +} + +func (vm *VM) Suspend() <-chan error { + errCh := make(chan error, 1) + + errCh <- fmt.Errorf("suspending Vetu VMs is not supported at the moment") + + return errCh +} + +func (vm *VM) Stop() <-chan error { + errCh := make(chan error, 1) + + select { + case <-vm.ctx.Done(): + // VM is already suspended/stopped + errCh <- nil + + return errCh + default: + // VM is still running + } + + vm.SetStatusMessage("Stopping VM") + vm.ConditionsSet().Add(v1.ConditionTypeStopping) + + go func() { + // Try to gracefully terminate the VM + _, _, _ = Vetu(context.Background(), zap.NewNop().Sugar(), "stop", "--timeout", "5", vm.id()) + + // Terminate the VM goroutine ("vetu pull", "vetu clone", "vetu run", etc.) via the context + vm.cancel() + vm.wg.Wait() + + // We don't return an error because we always terminate a VM + errCh <- nil + }() + + return errCh +} + +func (vm *VM) Start(eventStreamer *client.EventStreamer) { + vm.SetStatusMessage("Starting VM") + vm.ConditionsSet().Add(v1.ConditionTypeRunning) + + vm.cancel() + + vm.ctx, vm.cancel = context.WithCancel(context.Background()) + vm.wg.Add(1) + + go func() { + defer vm.wg.Done() + + vm.run(vm.ctx, eventStreamer) + }() +} + +func (vm *VM) Delete() error { + // Cancel all currently running Vetu invocations + // (e.g. "vetu clone", "vetu run", etc.) + vm.cancel() + + if vm.ConditionsSet().Contains(v1.ConditionTypeCloning) { + // Not cloned yet, nothing to delete + return nil + } + + _, _, err := Vetu(context.Background(), vm.logger, "delete", vm.id()) + if err != nil { + return fmt.Errorf("%w: failed to delete VM: %v", base.ErrVMFailed, err) + } + + return nil +} diff --git a/internal/worker/vmmanager/vmmanager.go b/internal/worker/vmmanager/vmmanager.go index 6c8a198..26622ad 100644 --- a/internal/worker/vmmanager/vmmanager.go +++ b/internal/worker/vmmanager/vmmanager.go @@ -26,6 +26,13 @@ type VM interface { Delete() error } +type VMInfo struct { + Name string + Source string + State string + Running bool +} + type VMManager struct { vms *xsync.Map[ondiskname.OnDiskName, VM] } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 3dd88db..7cccc00 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -9,15 +9,16 @@ import ( "slices" "time" + goruntime "runtime" + "github.com/avast/retry-go/v4" "github.com/cirruslabs/orchard/internal/dialer" "github.com/cirruslabs/orchard/internal/opentelemetry" "github.com/cirruslabs/orchard/internal/worker/dhcpleasetime" "github.com/cirruslabs/orchard/internal/worker/ondiskname" "github.com/cirruslabs/orchard/internal/worker/platform" + "github.com/cirruslabs/orchard/internal/worker/runtime" "github.com/cirruslabs/orchard/internal/worker/vmmanager" - "github.com/cirruslabs/orchard/internal/worker/vmmanager/synthetic" - "github.com/cirruslabs/orchard/internal/worker/vmmanager/tart" "github.com/cirruslabs/orchard/pkg/client" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/cirruslabs/orchard/rpc" @@ -54,7 +55,7 @@ type Worker struct { defaultCPU uint64 defaultMemory uint64 - synthetic bool + runtime runtime.Runtime vmPullTimeHistogram metric.Float64Histogram @@ -90,8 +91,18 @@ func New(client *client.Client, opts ...Option) (*Worker, error) { worker.name += worker.nameSuffix } - defaultResources := v1.Resources{ - v1.ResourceTartVMs: 2, + if worker.runtime == nil { + if goruntime.GOOS == "linux" { + worker.runtime = runtime.NewVetu() + } else { + worker.runtime = runtime.NewTart() + } + } + + defaultResources := v1.Resources{} + + if worker.runtime.ID() == v1.RuntimeTart { + defaultResources[v1.ResourceTartVMs] = 2 } // Determine the number of the host's logical CPU cores @@ -130,7 +141,7 @@ func New(client *client.Client, opts ...Option) (*Worker, error) { } func (worker *Worker) Run(ctx context.Context) error { - if !worker.synthetic { + if worker.runtime.ID() == v1.RuntimeTart { if err := dhcpleasetime.Check(); err != nil { worker.logger.Warnf("%v", err) } @@ -278,6 +289,8 @@ func (worker *Worker) registerWorker(ctx context.Context) error { Meta: v1.Meta{ Name: worker.name, }, + Arch: v1.Architecture(goruntime.GOARCH), + Runtime: worker.runtime.ID(), Resources: worker.resources, Labels: worker.labels, LastSeen: time.Now(), @@ -502,7 +515,7 @@ func (worker *Worker) syncVMs(ctx context.Context, updateVM func(context.Context //nolint:nestif,gocognit // complexity is tolerable for now func (worker *Worker) syncOnDiskVMs(ctx context.Context) error { - if worker.synthetic { + if worker.runtime.Synthetic() { // There's no on-disk VMs when using synthetic VMs return nil } @@ -518,7 +531,7 @@ func (worker *Worker) syncOnDiskVMs(ctx context.Context) error { worker.logger.Infof("syncing on-disk VMs...") - vmInfos, err := tart.List(ctx, worker.logger) + vmInfos, err := worker.runtime.ListVMs(ctx, worker.logger) if err != nil { return err } @@ -543,13 +556,13 @@ func (worker *Worker) syncOnDiskVMs(ctx context.Context) error { // On-disk VM doesn't exist on the controller nor in the Worker's VM manager, // stop it (if applicable) and delete it if vmInfo.Running { - _, _, err := tart.Tart(ctx, worker.logger, "stop", vmInfo.Name) + _, _, err := worker.runtime.Cmd(ctx, worker.logger, "stop", vmInfo.Name) if err != nil { worker.logger.Warnf("failed to stop") } } - _, _, err := tart.Tart(ctx, worker.logger, "delete", vmInfo.Name) + _, _, err := worker.runtime.Cmd(ctx, worker.logger, "delete", vmInfo.Name) if err != nil { return err } @@ -558,7 +571,7 @@ func (worker *Worker) syncOnDiskVMs(ctx context.Context) error { // but we've lost track of it, so shut it down (if applicable) // and report the error (if not failed yet) if vmInfo.Running { - _, _, err := tart.Tart(ctx, worker.logger, "stop", vmInfo.Name) + _, _, err := worker.runtime.Cmd(ctx, worker.logger, "stop", vmInfo.Name) if err != nil { worker.logger.Warnf("failed to stop") } @@ -584,14 +597,7 @@ func (worker *Worker) deleteVM(vm vmmanager.VM) error { func (worker *Worker) createVM(odn ondiskname.OnDiskName, vmResource v1.VM) { eventStreamer := worker.client.VMs().StreamEvents(vmResource.Name) - var vm vmmanager.VM - - if worker.synthetic { - vm = synthetic.NewVM(vmResource, eventStreamer, worker.vmPullTimeHistogram, worker.logger) - } else { - vm = tart.NewVM(vmResource, eventStreamer, worker.vmPullTimeHistogram, - worker.dialer, worker.logger) - } + vm := worker.runtime.NewVM(vmResource, eventStreamer, worker.vmPullTimeHistogram, worker.dialer, worker.logger) worker.vmm.Put(odn, vm) } diff --git a/pkg/resource/v1/v1.go b/pkg/resource/v1/v1.go index 4aca671..7feabc0 100644 --- a/pkg/resource/v1/v1.go +++ b/pkg/resource/v1/v1.go @@ -1,6 +1,8 @@ package v1 import ( + "encoding/json" + "fmt" "time" ) @@ -136,6 +138,24 @@ func (vm *VM) IsScheduled() bool { } type VMSpec struct { + // OS defines the operating system used by a VM. + // + // Previously only Darwin was supported, + // so this field defaults to that when not set. + OS OS `json:"os,omitempty"` + + // Arch defines the hardware architecture to use for a VM. + // + // Previously, only Tart on arm64 was supported, + // so this field defaults to that when not set. + Arch Architecture `json:"arch,omitempty"` + + // Runtime defines the runtime to use for a VM. + // + // Previously, only Tart on arm64 was supported, + // so this field defaults to that when not set. + Runtime Runtime `json:"runtime,omitempty"` + NetSoftnetDeprecated bool `json:"net-softnet,omitempty"` NetSoftnet bool `json:"netSoftnet,omitempty"` NetSoftnetAllow []string `json:"netSoftnetAllow,omitempty"` @@ -145,6 +165,9 @@ type VMSpec struct { } type VMSpecReadOnly struct { + LocalName string `json:"localName,omitempty"` + + // Deprecated: use LocalName instead. TartName string `json:"tartName,omitempty"` } @@ -156,6 +179,111 @@ type VMState struct { Conditions []Condition `json:"conditions,omitempty"` } +type OS string + +const ( + OSDarwin OS = "darwin" + OSLinux OS = "linux" +) + +func NewOSFromString(osRaw string) (OS, error) { + switch osRaw { + case "", string(OSDarwin): + return OSDarwin, nil + case string(OSLinux): + return OSLinux, nil + default: + return "", fmt.Errorf("unsupported OS: %q", osRaw) + } +} + +func (os *OS) UnmarshalJSON(data []byte) error { + var osRaw string + + if err := json.Unmarshal(data, &osRaw); err != nil { + return err + } + + parsedOS, err := NewOSFromString(osRaw) + if err != nil { + return err + } + + *os = parsedOS + + return nil +} + +type Architecture string + +const ( + ArchitectureARM64 Architecture = "arm64" + ArchitectureAMD64 Architecture = "amd64" +) + +func NewArchitectureFromString(rawArch string) (Architecture, error) { + switch rawArch { + case "", string(ArchitectureARM64): + return ArchitectureARM64, nil + case string(ArchitectureAMD64): + return ArchitectureAMD64, nil + default: + return "", fmt.Errorf("unsupported architecture: %q", rawArch) + } +} + +func (arch *Architecture) UnmarshalJSON(data []byte) error { + var rawArch string + + if err := json.Unmarshal(data, &rawArch); err != nil { + return err + } + + parsedArch, err := NewArchitectureFromString(rawArch) + if err != nil { + return err + } + + *arch = parsedArch + + return nil +} + +type Runtime string + +const ( + RuntimeTart Runtime = "tart" + RuntimeVetu Runtime = "vetu" +) + +func NewRuntimeFromString(rawRuntime string) (Runtime, error) { + switch rawRuntime { + case "", string(RuntimeTart): + return RuntimeTart, nil + case string(RuntimeVetu): + return RuntimeVetu, nil + default: + return "", fmt.Errorf("unsupported runtime: %q", rawRuntime) + } +} + +func (runtime *Runtime) UnmarshalJSON(data []byte) error { + var rawRuntime string + + if err := json.Unmarshal(data, &rawRuntime); err != nil { + return err + } + + parsedRuntime, err := NewRuntimeFromString(rawRuntime) + if err != nil { + return err + } + + *runtime = parsedRuntime + + return nil +} + type PowerState string const ( diff --git a/pkg/resource/v1/worker.go b/pkg/resource/v1/worker.go index 3741170..3e9c3f6 100644 --- a/pkg/resource/v1/worker.go +++ b/pkg/resource/v1/worker.go @@ -24,6 +24,12 @@ type Worker struct { // when it doesn't explicitly request a specific amount. DefaultMemory uint64 `json:"defaultMemory,omitempty"` + // Arch defines worker's hardware architecture. + Arch Architecture `json:"arch,omitempty"` + + // Runtime defines a runtime provided by this worker. + Runtime Runtime `json:"runtime,omitempty"` + Meta }