117 lines
2.9 KiB
Go
117 lines
2.9 KiB
Go
package worker
|
|
|
|
import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/cirruslabs/orchard/internal/proxy"
	v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
	"github.com/cirruslabs/orchard/rpc"
	//nolint:staticcheck // https://github.com/mitchellh/go-grpc-net-conn/pull/1
	"github.com/golang/protobuf/proto"
	grpc_net_conn "github.com/mitchellh/go-grpc-net-conn"
	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/metadata"
	"google.golang.org/protobuf/types/known/emptypb"
)
|
|
|
|
func (worker *Worker) watchRPC(ctx context.Context) error {
|
|
conn, err := grpc.Dial(worker.client.GRPCTarget(),
|
|
grpc.WithTransportCredentials(worker.client.GRPCTransportCredentials()),
|
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
Time: 30 * time.Second,
|
|
}),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
client := rpc.NewControllerClient(conn)
|
|
|
|
ctxWithMetadata := metadata.NewOutgoingContext(ctx, worker.GPRCMetadata())
|
|
|
|
stream, err := client.Watch(ctxWithMetadata, &emptypb.Empty{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for {
|
|
watchFromController, err := stream.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
portForwardAction, ok := watchFromController.Action.(*rpc.WatchInstruction_PortForwardAction)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
go worker.handlePortForward(ctxWithMetadata, client, portForwardAction.PortForwardAction)
|
|
}
|
|
}
|
|
|
|
func (worker *Worker) handlePortForward(
|
|
ctx context.Context,
|
|
client rpc.ControllerClient,
|
|
portForwardAction *rpc.WatchInstruction_PortForward,
|
|
) {
|
|
subCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
grpcMetadata := metadata.Join(
|
|
worker.GPRCMetadata(),
|
|
metadata.Pairs(rpc.MetadataWorkerPortForwardingSessionKey, portForwardAction.Session),
|
|
)
|
|
ctxWithMetadata := metadata.NewOutgoingContext(subCtx, grpcMetadata)
|
|
stream, err := client.PortForward(ctxWithMetadata)
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to call PortForward() RPC method: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
// Obtain VM
|
|
vm, err := worker.vmm.Get(v1.VM{
|
|
UID: portForwardAction.VmUid,
|
|
})
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to get the VM: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
// Obtain VM's IP address
|
|
ip, err := vm.IP(ctx)
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to get VM's IP: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
// Connect to the VM's port
|
|
vmConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", ip, portForwardAction.VmPort))
|
|
if err != nil {
|
|
worker.logger.Warnf("port forwarding failed: failed to connect to the VM: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
// Proxy bytes
|
|
grpcConn := &grpc_net_conn.Conn{
|
|
Stream: stream,
|
|
Request: &rpc.PortForwardData{},
|
|
Response: &rpc.PortForwardData{},
|
|
Encode: grpc_net_conn.SimpleEncoder(func(message proto.Message) *[]byte {
|
|
return &message.(*rpc.PortForwardData).Data
|
|
}),
|
|
Decode: grpc_net_conn.SimpleDecoder(func(message proto.Message) *[]byte {
|
|
return &message.(*rpc.PortForwardData).Data
|
|
}),
|
|
}
|
|
|
|
_ = proxy.Connections(vmConn, grpcConn)
|
|
}
|