189 lines
4.9 KiB
Go
189 lines
4.9 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/cirruslabs/orchard/internal/proxy"
|
|
"github.com/cirruslabs/orchard/internal/worker/vmmanager"
|
|
"github.com/cirruslabs/orchard/rpc"
|
|
"google.golang.org/grpc/keepalive"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
|
|
"net"
|
|
|
|
//nolint:staticcheck // https://github.com/mitchellh/go-grpc-net-conn/pull/1
|
|
"github.com/golang/protobuf/proto"
|
|
grpc_net_conn "github.com/mitchellh/go-grpc-net-conn"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/metadata"
|
|
|
|
"github.com/samber/lo"
|
|
)
|
|
|
|
func (worker *Worker) watchRPC(ctx context.Context) error {
|
|
worker.logger.Infof("connecting to %s over gRPC", worker.client.GRPCTarget())
|
|
|
|
conn, err := grpc.NewClient(worker.client.GRPCTarget(),
|
|
grpc.WithTransportCredentials(worker.client.GRPCTransportCredentials()),
|
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
Time: 30 * time.Second,
|
|
}),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
worker.logger.Infof("gRPC connection established, starting gRPC stream with the controller")
|
|
|
|
client := rpc.NewControllerClient(conn)
|
|
|
|
ctxWithMetadata := metadata.NewOutgoingContext(ctx, worker.grpcMetadata())
|
|
|
|
stream, err := client.Watch(ctxWithMetadata, &emptypb.Empty{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
worker.logger.Infof("running gRPC stream with the controller")
|
|
|
|
for {
|
|
watchFromController, err := stream.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch action := watchFromController.Action.(type) {
|
|
case *rpc.WatchInstruction_PortForwardAction:
|
|
go worker.handlePortForward(ctxWithMetadata, client, action.PortForwardAction)
|
|
case *rpc.WatchInstruction_SyncVmsAction:
|
|
worker.requestVMSyncing()
|
|
case *rpc.WatchInstruction_ResolveIpAction:
|
|
go worker.handleGetIP(ctxWithMetadata, client, action.ResolveIpAction)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (worker *Worker) handlePortForward(
|
|
ctx context.Context,
|
|
client rpc.ControllerClient,
|
|
portForwardAction *rpc.WatchInstruction_PortForward,
|
|
) {
|
|
subCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
grpcMetadata := metadata.Join(
|
|
worker.grpcMetadata(),
|
|
metadata.Pairs(rpc.MetadataWorkerPortForwardingSessionKey, portForwardAction.Session),
|
|
)
|
|
ctxWithMetadata := metadata.NewOutgoingContext(subCtx, grpcMetadata)
|
|
stream, err := client.PortForward(ctxWithMetadata)
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to call PortForward() RPC method: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
var host string
|
|
|
|
if portForwardAction.VmUid == "" {
|
|
// Port-forwarding request to a worker
|
|
host = "localhost"
|
|
} else {
|
|
// Port-forwarding request to a VM, find that VM
|
|
vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool {
|
|
return item.Resource().UID == portForwardAction.VmUid
|
|
})
|
|
if !ok {
|
|
worker.logger.Warnf("port forwarding failed: failed to get the VM: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
// Obtain VM's IP address
|
|
host, err = vm.IP(ctx)
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to get VM's IP: %v", err)
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
// Connect to the VM's port
|
|
var vmConn net.Conn
|
|
|
|
if worker.dialer != nil {
|
|
vmConn, err = worker.dialer.DialContext(ctx, "tcp",
|
|
fmt.Sprintf("%s:%d", host, portForwardAction.Port))
|
|
} else {
|
|
dialer := net.Dialer{}
|
|
|
|
vmConn, err = dialer.DialContext(ctx, "tcp",
|
|
fmt.Sprintf("%s:%d", host, portForwardAction.Port))
|
|
}
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to connect to the VM: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
// Proxy bytes
|
|
grpcConn := &grpc_net_conn.Conn{
|
|
Stream: stream,
|
|
Request: &rpc.PortForwardData{},
|
|
Response: &rpc.PortForwardData{},
|
|
Encode: grpc_net_conn.SimpleEncoder(func(message proto.Message) *[]byte {
|
|
return &message.(*rpc.PortForwardData).Data
|
|
}),
|
|
Decode: grpc_net_conn.SimpleDecoder(func(message proto.Message) *[]byte {
|
|
return &message.(*rpc.PortForwardData).Data
|
|
}),
|
|
}
|
|
|
|
_ = proxy.Connections(vmConn, grpcConn)
|
|
}
|
|
|
|
func (worker *Worker) handleGetIP(
|
|
ctx context.Context,
|
|
client rpc.ControllerClient,
|
|
resolveIP *rpc.WatchInstruction_ResolveIP,
|
|
) {
|
|
grpcMetadata := metadata.Join(
|
|
worker.grpcMetadata(),
|
|
metadata.Pairs(rpc.MetadataWorkerPortForwardingSessionKey, resolveIP.Session),
|
|
)
|
|
ctxWithMetadata := metadata.NewOutgoingContext(ctx, grpcMetadata)
|
|
|
|
// Find the desired VM
|
|
vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool {
|
|
return item.Resource().UID == resolveIP.VmUid
|
|
})
|
|
if !ok {
|
|
worker.logger.Warnf("failed to resolve IP for the VM with UID %q: VM not found",
|
|
resolveIP.VmUid)
|
|
|
|
return
|
|
}
|
|
|
|
// Obtain VM's IP address
|
|
ip, err := vm.IP(ctx)
|
|
if err != nil {
|
|
worker.logger.Warnf("failed to resolve IP for the VM with UID %q: \"tart ip\" failed: %v",
|
|
resolveIP.VmUid, err)
|
|
|
|
return
|
|
}
|
|
|
|
_, err = client.ResolveIP(ctxWithMetadata, &rpc.ResolveIPResult{
|
|
Session: resolveIP.Session,
|
|
Ip: ip,
|
|
})
|
|
if err != nil {
|
|
worker.logger.Warnf("failed to resolve IP for the VM with UID %q: "+
|
|
"failed to call back to the controller: %v", resolveIP.VmUid, err)
|
|
|
|
return
|
|
}
|
|
}
|