orchard/internal/worker/rpcv2.go

161 lines
4.2 KiB
Go

package worker
import (
"context"
"fmt"
"net"
"github.com/cirruslabs/orchard/internal/proxy"
"github.com/cirruslabs/orchard/internal/worker/vmmanager"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/samber/lo"
)
// watchRPCV2 subscribes to the controller's RPC v2 watch stream for this
// worker (identified by worker.name) and dispatches every received
// instruction until the stream reports an error or ctx is cancelled.
//
// Port-forwarding and IP resolution actions are handled in separate
// goroutines, so the loop keeps receiving instructions while they run.
// A VM sync action only triggers worker.requestVMSyncing(). Instructions
// carrying none of these actions are ignored.
//
// It returns the error from starting the watch, the first error delivered
// on the watch error channel, or ctx.Err() on cancellation.
func (worker *Worker) watchRPCV2(ctx context.Context) error {
	instructions, watchErrors, err := worker.client.RPC().Watch(ctx, worker.name)
	if err != nil {
		return err
	}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case watchErr := <-watchErrors:
			return watchErr
		case instruction := <-instructions:
			switch {
			case instruction.PortForwardAction != nil:
				go worker.handlePortForwardV2(ctx, instruction.PortForwardAction)
			case instruction.SyncVMsAction != nil:
				worker.requestVMSyncing()
			case instruction.ResolveIPAction != nil:
				go worker.handleGetIPV2(ctx, instruction.ResolveIPAction)
			}
		}
	}
}
// handlePortForwardV2 serves a single port-forwarding request received over
// the RPC v2 watch stream.
//
// It first connects to the requested port via handlePortForwardV2Inner, then
// reports the outcome to the controller with RespondPortForward. The error
// message sent is empty on success. If both steps succeed, bytes are proxied
// between the target connection and the connection returned by the
// controller.
//
// Nothing is returned; all failures are logged.
func (worker *Worker) handlePortForwardV2(ctx context.Context, portForward *v1.PortForwardAction) {
	var errorMessage string

	worker.logger.Debugf("received port-forwarding request to VM UID %s, port %d",
		portForward.VMUID, portForward.Port)

	// Establish a connection with the VM (or with the worker itself)
	vmConn, err := worker.handlePortForwardV2Inner(ctx, portForward)
	if err != nil {
		errorMessage = fmt.Sprintf("port-forwarding failed: %v", err)
		worker.logger.Warn(errorMessage)
	} else {
		// Close the target connection on every exit path, including when
		// RespondPortForward below fails. Without this, the socket would leak
		// on that path. If the connection was already closed while proxying,
		// this Close only returns an error, which is deliberately ignored.
		defer func() {
			_ = vmConn.Close()
		}()
	}

	// Respond to the controller with the outcome of the connection attempt
	netConn, err := worker.client.RPC().RespondPortForward(ctx, portForward.Session, errorMessage)
	if err != nil {
		worker.logger.Warnf("port forwarding failed: failed to call API: %v", err)
		return
	}
	defer func() {
		// Ensure that we always close the accepted WebSocket connection,
		// otherwise resource leak is possible[1]
		//
		// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
		_ = netConn.Close()
	}()

	// Proxy bytes only if the connection was established without errors
	if errorMessage == "" {
		_ = proxy.Connections(vmConn, netConn)
	}
}
// handlePortForwardV2Inner dials the TCP port portForward.Port and returns
// the established connection.
//
// When portForward.VMUID is empty, the target host is "localhost" (the worker
// itself). Otherwise the VM with that UID is looked up in worker.vmm.List()
// and its IP address is obtained via vm.IP.
//
// worker.dialer is used when it is non-nil; otherwise a zero-value
// net.Dialer is used.
//
// It returns an error if the VM cannot be found, its IP cannot be obtained,
// or the dial fails. Underlying errors are wrapped with %w.
func (worker *Worker) handlePortForwardV2Inner(
	ctx context.Context,
	portForward *v1.PortForwardAction,
) (net.Conn, error) {
	// Port-forwarding request to the worker itself by default
	host := "localhost"

	if portForward.VMUID != "" {
		// Port-forwarding request to a VM, find that VM
		vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool {
			return item.Resource().UID == portForward.VMUID
		})
		if !ok {
			// There is no underlying error to report here; the VM is simply
			// absent, so name it instead of formatting a nil error.
			return nil, fmt.Errorf("failed to get the VM: VM %q not found", portForward.VMUID)
		}

		// Obtain VM's IP address
		vmHost, err := vm.IP(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to get VM's IP: %w", err)
		}
		host = vmHost
	}

	// net.JoinHostPort brackets IPv6 literals (e.g. "[fe80::1]:22"), which a
	// plain "%s:%d" format would not, producing an undialable address.
	address := net.JoinHostPort(host, fmt.Sprintf("%d", portForward.Port))

	// Connect to the target port
	var (
		vmConn net.Conn
		err    error
	)
	if worker.dialer != nil {
		vmConn, err = worker.dialer.DialContext(ctx, "tcp", address)
	} else {
		var dialer net.Dialer
		vmConn, err = dialer.DialContext(ctx, "tcp", address)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to connect to the VM: %w", err)
	}

	return vmConn, nil
}
// handleGetIPV2 serves a single IP resolution request received over the RPC
// v2 watch stream.
//
// It resolves the IP of the VM identified by resolveIP.VMUID via
// handleGetIPV2Inner and reports the result to the controller with
// RespondIP. It sends the IP returned by the lookup together with an error
// message, which is empty on success.
//
// Nothing is returned; lookup and callback failures are logged.
func (worker *Worker) handleGetIPV2(ctx context.Context, resolveIP *v1.ResolveIPAction) {
	worker.logger.Debugf("received IP resolution request to VM UID %s", resolveIP.VMUID)

	ip, resolveErr := worker.handleGetIPV2Inner(ctx, resolveIP)

	errorMessage := ""
	if resolveErr != nil {
		errorMessage = fmt.Sprintf("failed to resolve VM's IP: %v", resolveErr)
		worker.logger.Warn(errorMessage)
	}

	// Report results back to the controller
	respondErr := worker.client.RPC().RespondIP(ctx, resolveIP.Session, ip, errorMessage)
	if respondErr != nil {
		worker.logger.Warnf("failed to resolve IP for the VM with UID %q: "+
			"failed to call back to the controller: %v", resolveIP.VMUID, respondErr)
	}
}
// handleGetIPV2Inner finds the VM whose UID equals resolveIP.VMUID in
// worker.vmm.List() and returns its IP address as reported by vm.IP.
//
// It returns an error if no such VM exists or if vm.IP fails. In the latter
// case the underlying error is wrapped with %w, so callers can inspect it
// with errors.Is / errors.As.
func (worker *Worker) handleGetIPV2Inner(
	ctx context.Context,
	resolveIP *v1.ResolveIPAction,
) (string, error) {
	// Find the desired VM
	vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool {
		return item.Resource().UID == resolveIP.VMUID
	})
	if !ok {
		return "", fmt.Errorf("VM %q not found", resolveIP.VMUID)
	}

	// Obtain VM's IP address. The error text refers to "tart ip",
	// presumably the command vm.IP runs; confirm against vmmanager.
	ip, err := vm.IP(ctx)
	if err != nil {
		return "", fmt.Errorf("\"tart ip\" failed for VM %q: %w", resolveIP.VMUID, err)
	}

	return ip, nil
}