161 lines
4.2 KiB
Go
161 lines
4.2 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
|
|
"github.com/cirruslabs/orchard/internal/proxy"
|
|
"github.com/cirruslabs/orchard/internal/worker/vmmanager"
|
|
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
"github.com/samber/lo"
|
|
)
|
|
|
|
func (worker *Worker) watchRPCV2(ctx context.Context) error {
|
|
watchInstructionCh, watchErrCh, err := worker.client.RPC().Watch(ctx, worker.name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case watchInstruction := <-watchInstructionCh:
|
|
if portForwardAction := watchInstruction.PortForwardAction; portForwardAction != nil {
|
|
go worker.handlePortForwardV2(ctx, portForwardAction)
|
|
} else if syncVMsAction := watchInstruction.SyncVMsAction; syncVMsAction != nil {
|
|
worker.requestVMSyncing()
|
|
} else if resolveIPAction := watchInstruction.ResolveIPAction; resolveIPAction != nil {
|
|
go worker.handleGetIPV2(ctx, resolveIPAction)
|
|
}
|
|
case watchErr := <-watchErrCh:
|
|
return watchErr
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (worker *Worker) handlePortForwardV2(ctx context.Context, portForward *v1.PortForwardAction) {
|
|
var errorMessage string
|
|
|
|
worker.logger.Debugf("received port-forwarding request to VM UID %s, port %d",
|
|
portForward.VMUID, portForward.Port)
|
|
|
|
// Establish a connection with the VM
|
|
vmConn, err := worker.handlePortForwardV2Inner(ctx, portForward)
|
|
if err != nil {
|
|
errorMessage = fmt.Sprintf("port-forwarding failed: %v", err)
|
|
|
|
worker.logger.Warn(errorMessage)
|
|
}
|
|
|
|
// Respond
|
|
netConn, err := worker.client.RPC().RespondPortForward(ctx, portForward.Session, errorMessage)
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to call API: %v", err)
|
|
|
|
return
|
|
}
|
|
defer func() {
|
|
// Ensure that we always close the accepted WebSocket connection,
|
|
// otherwise resource leak is possible[1]
|
|
//
|
|
// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
|
|
_ = netConn.Close()
|
|
}()
|
|
|
|
// Proxy bytes if the connection was established without errors
|
|
if errorMessage == "" {
|
|
_ = proxy.Connections(vmConn, netConn)
|
|
}
|
|
}
|
|
|
|
func (worker *Worker) handlePortForwardV2Inner(
|
|
ctx context.Context,
|
|
portForward *v1.PortForwardAction,
|
|
) (net.Conn, error) {
|
|
var host string
|
|
var err error
|
|
|
|
if portForward.VMUID == "" {
|
|
// Port-forwarding request to a worker
|
|
host = "localhost"
|
|
} else {
|
|
// Port-forwarding request to a VM, find that VM
|
|
vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool {
|
|
return item.Resource().UID == portForward.VMUID
|
|
})
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to get the VM: %v", err)
|
|
}
|
|
|
|
// Obtain VM's IP address
|
|
host, err = vm.IP(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get VM's IP: %v", err)
|
|
}
|
|
}
|
|
|
|
// Connect to the VM's port
|
|
|
|
var vmConn net.Conn
|
|
|
|
if worker.dialer != nil {
|
|
vmConn, err = worker.dialer.DialContext(ctx, "tcp",
|
|
fmt.Sprintf("%s:%d", host, portForward.Port))
|
|
} else {
|
|
dialer := net.Dialer{}
|
|
|
|
vmConn, err = dialer.DialContext(ctx, "tcp",
|
|
fmt.Sprintf("%s:%d", host, portForward.Port))
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to the VM: %v", err)
|
|
}
|
|
|
|
return vmConn, nil
|
|
}
|
|
|
|
func (worker *Worker) handleGetIPV2(ctx context.Context, resolveIP *v1.ResolveIPAction) {
|
|
var errorMessage string
|
|
|
|
worker.logger.Debugf("received IP resolution request to VM UID %s", resolveIP.VMUID)
|
|
|
|
// Retrieve the VM's IP
|
|
ip, err := worker.handleGetIPV2Inner(ctx, resolveIP)
|
|
if err != nil {
|
|
errorMessage = fmt.Sprintf("failed to resolve VM's IP: %v", err)
|
|
|
|
worker.logger.Warn(errorMessage)
|
|
}
|
|
|
|
// Report results
|
|
if err := worker.client.RPC().RespondIP(ctx, resolveIP.Session, ip, errorMessage); err != nil {
|
|
worker.logger.Warnf("failed to resolve IP for the VM with UID %q: "+
|
|
"failed to call back to the controller: %v", resolveIP.VMUID, err)
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
func (worker *Worker) handleGetIPV2Inner(
|
|
ctx context.Context,
|
|
resolveIP *v1.ResolveIPAction,
|
|
) (string, error) {
|
|
// Find the desired VM
|
|
vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool {
|
|
return item.Resource().UID == resolveIP.VMUID
|
|
})
|
|
if !ok {
|
|
return "", fmt.Errorf("VM %q not found", resolveIP.VMUID)
|
|
}
|
|
|
|
// Obtain VM's IP address
|
|
ip, err := vm.IP(ctx)
|
|
if err != nil {
|
|
return "", fmt.Errorf("\"tart ip\" failed for VM %q: %v", resolveIP.VMUID, err)
|
|
}
|
|
|
|
return ip, nil
|
|
}
|