fix: keep execute tunnel context alive through session
- keep the rendezvous context rooted in the request context instead of the timeout context
- limit the timeout context to the worker notify call and the initial rendezvous wait only
- add a regression test for the proxy context lifecycle (comment 2782213336)

🤖 Generated with [Codex](https://chatgpt.com/codex)

Co-Authored-By: Codex <codex@openai.com>
This commit is contained in:
parent
e17a80bb95
commit
86248be003
|
|
@ -116,16 +116,16 @@ func (controller *Controller) establishExecuteSSHTunnel(
|
|||
ctx context.Context,
|
||||
vm *v1.VM,
|
||||
) (net.Conn, responder.Responder) {
|
||||
tunnelCtx, tunnelCtxCancel := context.WithTimeout(ctx, executeSessionRendezvousTimeout)
|
||||
defer tunnelCtxCancel()
|
||||
tunnelWaitCtx, tunnelWaitCtxCancel := context.WithTimeout(ctx, executeSessionRendezvousTimeout)
|
||||
defer tunnelWaitCtxCancel()
|
||||
|
||||
rendezvousCtx, rendezvousCtxCancel := context.WithCancel(tunnelCtx)
|
||||
rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx)
|
||||
|
||||
session := uuid.New().String()
|
||||
connCh, cancelRequest := controller.connRendezvous.Request(rendezvousCtx, session)
|
||||
defer cancelRequest()
|
||||
|
||||
err := controller.workerNotifier.Notify(tunnelCtx, vm.Worker, &rpc.WatchInstruction{
|
||||
err := controller.workerNotifier.Notify(tunnelWaitCtx, vm.Worker, &rpc.WatchInstruction{
|
||||
Action: &rpc.WatchInstruction_PortForwardAction{
|
||||
PortForwardAction: &rpc.WatchInstruction_PortForward{
|
||||
Session: session,
|
||||
|
|
@ -159,15 +159,15 @@ func (controller *Controller) establishExecuteSSHTunnel(
|
|||
}
|
||||
|
||||
return netconncancel.New(rendezvousResp.Result, rendezvousCtxCancel), nil
|
||||
case <-tunnelCtx.Done():
|
||||
case <-tunnelWaitCtx.Done():
|
||||
rendezvousCtxCancel()
|
||||
|
||||
if errors.Is(tunnelCtx.Err(), context.DeadlineExceeded) {
|
||||
if errors.Is(tunnelWaitCtx.Err(), context.DeadlineExceeded) {
|
||||
return nil, responder.JSON(http.StatusServiceUnavailable, NewErrorResponse(
|
||||
"timed out waiting for worker %s to establish SSH tunnel", vm.Worker))
|
||||
}
|
||||
|
||||
return nil, responder.Error(ctx.Err())
|
||||
return nil, responder.Error(tunnelWaitCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,16 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/controller/notifier"
|
||||
"github.com/cirruslabs/orchard/internal/controller/rendezvous"
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type recordingWriteCloser struct {
|
||||
|
|
@ -165,3 +170,83 @@ func TestBuildSSHCommandQuotesArguments(t *testing.T) {
|
|||
|
||||
require.Equal(t, "'echo' 'hello world' 'a'\\''b' ''", result)
|
||||
}
|
||||
|
||||
// TestEstablishExecuteSSHTunnelKeepsProxyContextAliveUntilTunnelClosed is a
// regression test for the execute-tunnel context lifecycle: the rendezvous
// (proxy) context handed to the worker side must stay alive for the whole
// session and only be canceled when the tunnel connection is closed — not
// when the initial rendezvous timeout context expires.
func TestEstablishExecuteSSHTunnelKeepsProxyContextAliveUntilTunnelClosed(t *testing.T) {
	controller := &Controller{
		logger:         zap.NewNop().Sugar(),
		workerNotifier: notifier.NewNotifier(zap.NewNop().Sugar()),
		connRendezvous: rendezvous.New[rendezvous.ResultWithErrorMessage[net.Conn]](),
	}

	// Register a fake worker so the controller's notify call has a receiver.
	workerCh, cancelWorker := controller.workerNotifier.Register(context.Background(), "worker-1")
	defer cancelWorker()

	// Buffered so the worker goroutine never blocks handing results back.
	proxyCtxCh := make(chan context.Context, 1)
	workerErrCh := make(chan error, 1)

	// Simulated worker: waits for the port-forward instruction, answers the
	// rendezvous with one end of an in-memory pipe, and publishes the proxy
	// context it received so the test can observe its lifetime.
	go func() {
		msg := <-workerCh

		action := msg.GetPortForwardAction()
		if action == nil {
			workerErrCh <- errors.New("expected port forward action")
			return
		}

		tunnelConn, workerConn := net.Pipe()
		// Respond returns the context that governs the proxying session;
		// its cancellation is what this test asserts on.
		proxyCtx, err := controller.connRendezvous.Respond(action.Session, rendezvous.ResultWithErrorMessage[net.Conn]{
			Result: tunnelConn,
		})
		if err != nil {
			_ = tunnelConn.Close()
			_ = workerConn.Close()
			workerErrCh <- err
			return
		}

		// Mirror real worker behavior: tear down the worker side of the
		// pipe once the proxy context is canceled.
		go func() {
			<-proxyCtx.Done()
			_ = workerConn.Close()
		}()

		proxyCtxCh <- proxyCtx
	}()

	// Establish the tunnel; a nil responder means success.
	tunnel, responderImpl := controller.establishExecuteSSHTunnel(context.Background(), &v1.VM{
		Worker: "worker-1",
		UID:    "vm-uid",
	})
	require.Nil(t, responderImpl)

	var proxyCtx context.Context

	// Wait for the simulated worker to complete the rendezvous (or fail).
	select {
	case err := <-workerErrCh:
		require.NoError(t, err)
	case proxyCtx = <-proxyCtxCh:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for tunnel rendezvous response")
	}

	require.NotNil(t, proxyCtx)

	// Core assertion: the proxy context must still be alive after the
	// tunnel is established (the old bug canceled it with the timeout ctx).
	select {
	case <-proxyCtx.Done():
		t.Fatal("proxy context canceled before tunnel close")
	default:
	}

	require.NoError(t, tunnel.Close())

	// Closing the tunnel is what should cancel the proxy context.
	select {
	case <-proxyCtx.Done():
	case <-time.After(time.Second):
		t.Fatal("proxy context was not canceled after tunnel close")
	}

	// Surface any late error from the worker goroutine, non-blocking.
	select {
	case err := <-workerErrCh:
		require.NoError(t, err)
	default:
	}
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue