package worker import ( "context" "fmt" "time" "github.com/cirruslabs/orchard/internal/proxy" "github.com/cirruslabs/orchard/internal/worker/vmmanager" "github.com/cirruslabs/orchard/rpc" "google.golang.org/grpc/keepalive" "google.golang.org/protobuf/types/known/emptypb" "net" //nolint:staticcheck // https://github.com/mitchellh/go-grpc-net-conn/pull/1 "github.com/golang/protobuf/proto" grpc_net_conn "github.com/mitchellh/go-grpc-net-conn" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/samber/lo" ) func (worker *Worker) watchRPC(ctx context.Context) error { worker.logger.Infof("connecting to %s over gRPC", worker.client.GRPCTarget()) conn, err := grpc.NewClient(worker.client.GRPCTarget(), grpc.WithTransportCredentials(worker.client.GRPCTransportCredentials()), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, }), ) if err != nil { return err } worker.logger.Infof("gRPC connection established, starting gRPC stream with the controller") client := rpc.NewControllerClient(conn) ctxWithMetadata := metadata.NewOutgoingContext(ctx, worker.grpcMetadata()) stream, err := client.Watch(ctxWithMetadata, &emptypb.Empty{}) if err != nil { return err } worker.logger.Infof("running gRPC stream with the controller") for { watchFromController, err := stream.Recv() if err != nil { return err } switch action := watchFromController.Action.(type) { case *rpc.WatchInstruction_PortForwardAction: go worker.handlePortForward(ctxWithMetadata, client, action.PortForwardAction) case *rpc.WatchInstruction_SyncVmsAction: worker.requestVMSyncing() case *rpc.WatchInstruction_ResolveIpAction: go worker.handleGetIP(ctxWithMetadata, client, action.ResolveIpAction) } } } func (worker *Worker) handlePortForward( ctx context.Context, client rpc.ControllerClient, portForwardAction *rpc.WatchInstruction_PortForward, ) { subCtx, cancel := context.WithCancel(ctx) defer cancel() grpcMetadata := metadata.Join( worker.grpcMetadata(), metadata.Pairs(rpc.MetadataWorkerPortForwardingSessionKey, portForwardAction.Session), ) ctxWithMetadata := metadata.NewOutgoingContext(subCtx, grpcMetadata) stream, err := client.PortForward(ctxWithMetadata) if err != nil { worker.logger.Warnf("port forwarding failed: failed to call PortForward() RPC method: %v", err) return } var host string if portForwardAction.VmUid == "" { // Port-forwarding request to a worker host = "localhost" } else { // Port-forwarding request to a VM, find that VM vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool { return item.Resource().UID == portForwardAction.VmUid }) if !ok { worker.logger.Warnf("port forwarding failed: failed to get the VM: %v", err) return } // Obtain VM's IP address host, err = vm.IP(ctx) if err != nil { worker.logger.Warnf("port forwarding failed: failed to get VM's IP: %v", err) return } } // Connect to the VM's port var vmConn net.Conn if worker.dialer != nil { vmConn, err = worker.dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", host, portForwardAction.Port)) } else { dialer := net.Dialer{} vmConn, err = dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", host, portForwardAction.Port)) } if err != nil { worker.logger.Warnf("port forwarding failed: failed to connect to the VM: %v", err) return } // Proxy bytes grpcConn := &grpc_net_conn.Conn{ Stream: stream, Request: &rpc.PortForwardData{}, Response: &rpc.PortForwardData{}, Encode: grpc_net_conn.SimpleEncoder(func(message proto.Message) *[]byte { return &message.(*rpc.PortForwardData).Data }), Decode: grpc_net_conn.SimpleDecoder(func(message proto.Message) *[]byte { return &message.(*rpc.PortForwardData).Data }), } _ = proxy.Connections(vmConn, grpcConn) } func (worker *Worker) handleGetIP( ctx context.Context, client rpc.ControllerClient, resolveIP *rpc.WatchInstruction_ResolveIP, ) { grpcMetadata := metadata.Join( worker.grpcMetadata(), metadata.Pairs(rpc.MetadataWorkerPortForwardingSessionKey, resolveIP.Session), ) ctxWithMetadata := metadata.NewOutgoingContext(ctx, grpcMetadata) // Find the desired VM vm, ok := lo.Find(worker.vmm.List(), func(item vmmanager.VM) bool { return item.Resource().UID == resolveIP.VmUid }) if !ok { worker.logger.Warnf("failed to resolve IP for the VM with UID %q: VM not found", resolveIP.VmUid) return } // Obtain VM's IP address ip, err := vm.IP(ctx) if err != nil { worker.logger.Warnf("failed to resolve IP for the VM with UID %q: \"tart ip\" failed: %v", resolveIP.VmUid, err) return } _, err = client.ResolveIP(ctxWithMetadata, &rpc.ResolveIPResult{ Session: resolveIP.Session, Ip: ip, }) if err != nil { worker.logger.Warnf("failed to resolve IP for the VM with UID %q: "+ "failed to call back to the controller: %v", resolveIP.VmUid, err) return } }