diff --git a/internal/controller/api_vms_execute.go b/internal/controller/api_vms_execute.go index 83a6240..b2ef96f 100644 --- a/internal/controller/api_vms_execute.go +++ b/internal/controller/api_vms_execute.go @@ -67,7 +67,7 @@ func (controller *Controller) executeVM(ctx *gin.Context) responder.Responder { return responderImpl } - tunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, waitCtx, vm) + tunnel, responderImpl := controller.establishExecuteSSHTunnel(ctx, vm) if responderImpl != nil { return responderImpl } @@ -114,16 +114,18 @@ func acceptExecuteWebSocket(ctx *gin.Context) (*websocket.Conn, error) { func (controller *Controller) establishExecuteSSHTunnel( ctx context.Context, - waitCtx context.Context, vm *v1.VM, ) (net.Conn, responder.Responder) { - rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx) + tunnelCtx, tunnelCtxCancel := context.WithTimeout(ctx, executeSessionRendezvousTimeout) + defer tunnelCtxCancel() + + rendezvousCtx, rendezvousCtxCancel := context.WithCancel(tunnelCtx) session := uuid.New().String() connCh, cancelRequest := controller.connRendezvous.Request(rendezvousCtx, session) defer cancelRequest() - err := controller.workerNotifier.Notify(waitCtx, vm.Worker, &rpc.WatchInstruction{ + err := controller.workerNotifier.Notify(tunnelCtx, vm.Worker, &rpc.WatchInstruction{ Action: &rpc.WatchInstruction_PortForwardAction{ PortForwardAction: &rpc.WatchInstruction_PortForward{ Session: session, @@ -141,9 +143,6 @@ func (controller *Controller) establishExecuteSSHTunnel( return nil, responder.Code(http.StatusServiceUnavailable) } - timeoutTimer := time.NewTimer(executeSessionRendezvousTimeout) - defer timeoutTimer.Stop() - select { case rendezvousResp := <-connCh: if rendezvousResp.ErrorMessage != "" { @@ -160,13 +159,13 @@ func (controller *Controller) establishExecuteSSHTunnel( } return netconncancel.New(rendezvousResp.Result, rendezvousCtxCancel), nil - case <-timeoutTimer.C: + case <-tunnelCtx.Done(): rendezvousCtxCancel() - return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( - "timed out waiting for worker %s to establish SSH tunnel", vm.Worker)) - case <-ctx.Done(): - rendezvousCtxCancel() + if errors.Is(tunnelCtx.Err(), context.DeadlineExceeded) { + return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse( + "timed out waiting for worker %s to establish SSH tunnel", vm.Worker)) + } return nil, responder.Error(ctx.Err()) }