From 077252f6d4b2ff3e3d7023b5c87fa227f2a7df08 Mon Sep 17 00:00:00 2001 From: Nikolay Edigaryev Date: Thu, 23 Jan 2025 18:17:14 +0400 Subject: [PATCH] Prevent goroutine leak when Close()'ing *grpc_net_conn.Conn (#237) --- internal/controller/api_vms_portforward.go | 23 +++++++++++++++++---- internal/netconncancel/netconncancel.go | 24 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 internal/netconncancel/netconncancel.go diff --git a/internal/controller/api_vms_portforward.go b/internal/controller/api_vms_portforward.go index 2516174..6a0b3ee 100644 --- a/internal/controller/api_vms_portforward.go +++ b/internal/controller/api_vms_portforward.go @@ -3,6 +3,7 @@ package controller import ( "context" storepkg "github.com/cirruslabs/orchard/internal/controller/store" + "github.com/cirruslabs/orchard/internal/netconncancel" "github.com/cirruslabs/orchard/internal/proxy" "github.com/cirruslabs/orchard/internal/responder" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" @@ -10,6 +11,8 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "net/http" "nhooyr.io/websocket" "strconv" @@ -58,8 +61,12 @@ func (controller *Controller) portForward( port uint32, ) responder.Responder { // Request and wait for a connection with a worker + rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx) + defer rendezvousCtxCancel() + session := uuid.New().String() - boomerangConnCh, cancel := controller.connRendezvous.Request(ctx, session) + + boomerangConnCh, cancel := controller.connRendezvous.Request(rendezvousCtx, session) defer cancel() // send request to worker to initiate port-forwarding connection back to us @@ -98,13 +105,21 @@ func (controller *Controller) portForward( } wsConnAsNetConn := websocket.NetConn(ctx, wsConn, expectedMsgType) + fromWorkerConnectionWithCancel := netconncancel.New(fromWorkerConnection, rendezvousCtxCancel) - if err := proxy.Connections(wsConnAsNetConn, fromWorkerConnection); err != nil { + if err := proxy.Connections(wsConnAsNetConn, fromWorkerConnectionWithCancel); err != nil { var websocketCloseError websocket.CloseError // Normal closure from the user - if errors.As(err, &websocketCloseError) && - websocketCloseError.Code == websocket.StatusNormalClosure { + if errors.As(err, &websocketCloseError) && websocketCloseError.Code == websocket.StatusNormalClosure { + return responder.Empty() + } + + if errors.Is(err, context.Canceled) { + return responder.Empty() + } + + if status, ok := status.FromError(err); ok && status.Code() == codes.Canceled { return responder.Empty() } diff --git a/internal/netconncancel/netconncancel.go b/internal/netconncancel/netconncancel.go new file mode 100644 index 0000000..4d90a5d --- /dev/null +++ b/internal/netconncancel/netconncancel.go @@ -0,0 +1,24 @@ +package netconncancel + +import ( + "context" + "net" +) + +type NetConnCancel struct { + Cancel context.CancelFunc + net.Conn +} + +func New(netConn net.Conn, cancel context.CancelFunc) *NetConnCancel { + return &NetConnCancel{ + Cancel: cancel, + Conn: netConn, + } +} + +func (ncc *NetConnCancel) Close() error { + ncc.Cancel() + + return ncc.Conn.Close() +}