249 lines
6.7 KiB
Go
249 lines
6.7 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/avast/retry-go/v5"
|
|
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
|
|
"github.com/cirruslabs/orchard/internal/netconncancel"
|
|
"github.com/cirruslabs/orchard/internal/proxy"
|
|
"github.com/cirruslabs/orchard/internal/responder"
|
|
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
"github.com/cirruslabs/orchard/rpc"
|
|
"github.com/coder/websocket"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/pkg/errors"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
var errPortForwardRequest = errors.New("failed to request port forwarding")
|
|
|
|
func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responder {
|
|
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
|
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
|
return responder
|
|
}
|
|
|
|
// Retrieve and parse path and query parameters
|
|
name := ctx.Param("name")
|
|
|
|
portRaw := ctx.Query("port")
|
|
port, err := strconv.ParseUint(portRaw, 10, 16)
|
|
if err != nil {
|
|
return responder.Code(http.StatusBadRequest)
|
|
}
|
|
if port < 1 || port > 65535 {
|
|
return responder.Code(http.StatusBadRequest)
|
|
}
|
|
|
|
waitRaw := ctx.DefaultQuery("wait", "10")
|
|
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
|
if err != nil {
|
|
return responder.Code(http.StatusBadRequest)
|
|
}
|
|
|
|
// Look-up the VM
|
|
waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
|
defer waitContextCancel()
|
|
|
|
vm, responderImpl := controller.waitForVM(waitContext, name)
|
|
if responderImpl != nil {
|
|
return responderImpl
|
|
}
|
|
|
|
// Commence port forwarding
|
|
return controller.portForward(ctx, waitContext, vm.Worker, vm.UID, uint32(port))
|
|
}
|
|
|
|
func (controller *Controller) portForward(
|
|
ctx *gin.Context,
|
|
notifyContext context.Context,
|
|
workerName string,
|
|
vmUID string,
|
|
port uint32,
|
|
) responder.Responder {
|
|
// Request and wait for a connection with a worker
|
|
rendezvousConn, err := retry.NewWithData[net.Conn](
|
|
retry.Context(notifyContext),
|
|
retry.DelayType(retry.FixedDelay),
|
|
retry.Delay(time.Second),
|
|
retry.Attempts(0),
|
|
retry.LastErrorOnly(true),
|
|
).Do(func() (net.Conn, error) {
|
|
return controller.portForwardConnection(ctx, notifyContext, workerName, vmUID, port)
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, errPortForwardRequest) {
|
|
controller.logger.Warnf("failed to request port forwarding from the worker %s: %v",
|
|
workerName, err)
|
|
|
|
return responder.Code(http.StatusServiceUnavailable)
|
|
}
|
|
|
|
return responder.Error(err)
|
|
}
|
|
defer func() {
|
|
_ = rendezvousConn.Close()
|
|
}()
|
|
|
|
// Worker will asynchronously start port forwarding, so we wait
|
|
wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{
|
|
OriginPatterns: []string{"*"},
|
|
})
|
|
if err != nil {
|
|
return responder.Error(err)
|
|
}
|
|
defer func() {
|
|
// Ensure that we always close the accepted WebSocket connection,
|
|
// otherwise resource leak is possible[1]
|
|
//
|
|
// [1]: https://github.com/coder/websocket/issues/445#issuecomment-2053792044
|
|
_ = wsConn.CloseNow()
|
|
}()
|
|
|
|
expectedMsgType := websocket.MessageBinary
|
|
|
|
// Backwards compatibility with older Orchard clients
|
|
// using "golang.org/x/net/websocket" package
|
|
if ctx.Request.Header.Get("User-Agent") == "" {
|
|
expectedMsgType = websocket.MessageText
|
|
}
|
|
|
|
proxyConnectionsErrCh := make(chan error, 1)
|
|
wsConnAsNetConn := websocket.NetConn(ctx, wsConn, expectedMsgType)
|
|
|
|
go func() {
|
|
proxyConnectionsErrCh <- proxy.Connections(wsConnAsNetConn, rendezvousConn)
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case err := <-proxyConnectionsErrCh:
|
|
if err != nil {
|
|
var websocketCloseError websocket.CloseError
|
|
|
|
// Normal closure from the user
|
|
if errors.As(err, &websocketCloseError) && websocketCloseError.Code == websocket.StatusNormalClosure {
|
|
return responder.Empty()
|
|
}
|
|
|
|
if errors.Is(err, context.Canceled) {
|
|
return responder.Empty()
|
|
}
|
|
|
|
if status, ok := status.FromError(err); ok && status.Code() == codes.Canceled {
|
|
return responder.Empty()
|
|
}
|
|
|
|
controller.logger.Warnf("port forwarding: failed to proxy connections: %v", err)
|
|
}
|
|
|
|
return responder.Empty()
|
|
case <-time.After(controller.pingInterval):
|
|
pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
|
|
if err := wsConn.Ping(pingCtx); err != nil {
|
|
controller.logger.Warnf("port forwarding: failed to ping the client, "+
|
|
"connection might time out: %v", err)
|
|
}
|
|
|
|
pingCtxCancel()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (controller *Controller) waitForVM(ctx context.Context, name string) (*v1.VM, responder.Responder) {
|
|
var vm *v1.VM
|
|
var err error
|
|
|
|
for {
|
|
if lookupResponder := controller.storeView(func(txn storepkg.Transaction) responder.Responder {
|
|
vm, err = txn.GetVM(name)
|
|
if err != nil {
|
|
return responder.Error(err)
|
|
}
|
|
|
|
return nil
|
|
}); lookupResponder != nil {
|
|
return nil, lookupResponder
|
|
}
|
|
|
|
if vm.TerminalState() {
|
|
return nil, responder.JSON(http.StatusExpectationFailed,
|
|
NewErrorResponse("VM is in a terminal state '%s'", vm.Status))
|
|
}
|
|
if vm.Status == v1.VMStatusRunning {
|
|
// VM is running, proceed
|
|
return vm, nil
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, responder.JSON(http.StatusRequestTimeout,
|
|
NewErrorResponse("VM is not running on '%s' worker", vm.Worker))
|
|
case <-time.After(1 * time.Second):
|
|
// try again
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func (controller *Controller) portForwardConnection(
|
|
ctx context.Context,
|
|
waitContext context.Context,
|
|
workerName string,
|
|
vmUID string,
|
|
port uint32,
|
|
) (net.Conn, error) {
|
|
// Create a rendezvous connection point
|
|
rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx)
|
|
|
|
session := uuid.New().String()
|
|
boomerangConnCh, boomerangConnCancel := controller.connRendezvous.Request(rendezvousCtx, session)
|
|
|
|
cancel := func() {
|
|
boomerangConnCancel()
|
|
rendezvousCtxCancel()
|
|
}
|
|
|
|
// Send request to a worker to initiate a port forwarding connection back to us
|
|
err := controller.workerNotifier.Notify(waitContext, workerName, &rpc.WatchInstruction{
|
|
Action: &rpc.WatchInstruction_PortForwardAction{
|
|
PortForwardAction: &rpc.WatchInstruction_PortForward{
|
|
Session: session,
|
|
VmUid: vmUID,
|
|
Port: port,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
cancel()
|
|
|
|
return nil, fmt.Errorf("%w: failed to request port forwarding from the worker %s: %v",
|
|
errPortForwardRequest, workerName, err)
|
|
}
|
|
|
|
// Wait for the worker to respond
|
|
select {
|
|
case rendezvousResponse := <-boomerangConnCh:
|
|
if rendezvousResponse.ErrorMessage != "" {
|
|
cancel()
|
|
|
|
return nil, fmt.Errorf("failed to establish port forwarding session on the worker: %s",
|
|
rendezvousResponse.ErrorMessage)
|
|
}
|
|
|
|
return netconncancel.New(rendezvousResponse.Result, cancel), nil
|
|
case <-waitContext.Done():
|
|
cancel()
|
|
|
|
return nil, waitContext.Err()
|
|
}
|
|
}
|