diff --git a/go.mod b/go.mod index 9971808..8030591 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/manifoldco/promptui v0.9.0 github.com/mitchellh/go-grpc-net-conn v0.0.0-20200427190222-eb030e4876f0 github.com/penglongli/gin-metrics v0.1.10 + github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.15.0 github.com/samber/lo v1.38.1 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -34,6 +35,7 @@ require ( gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 howett.net/plist v1.0.0 + nhooyr.io/websocket v1.8.7 ) require ( @@ -74,7 +76,7 @@ require ( github.com/inconshreveable/mousetrap v1.0.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.13.6 // indirect + github.com/klauspost/compress v1.16.7 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/mailru/easyjson v0.7.7 // indirect @@ -87,7 +89,6 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/common v0.42.0 // indirect diff --git a/go.sum b/go.sum index cf5ba91..26a759a 100644 --- a/go.sum +++ b/go.sum @@ -108,6 +108,7 @@ github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYF github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gin-gonic/gin v1.7.4/go.mod h1:jD2toBW3GZUr5UMcdrwQA10I7RuaFOl/SGeDjXkfUtY= github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8= github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k= @@ -155,6 +156,7 @@ github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/Nu github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= @@ -185,6 +187,12 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA= github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= @@ -259,6 +267,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gosuri/uitable v0.0.4 h1:IG2xLKRvErL3uhY6e1BylFzG+aJiwQviDDTfOKeKTpY= github.com/gosuri/uitable v0.0.4/go.mod h1:tKR86bXuXPZazfOTG1FIzvjIdXzd0mo4Vtn16vt0PJo= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= @@ -293,9 +303,11 @@ github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaR github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -800,6 +812,8 @@ honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9 honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= +nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= +nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/internal/command/ssh/vm.go b/internal/command/ssh/vm.go index 34f1b72..41776fb 100644 --- a/internal/command/ssh/vm.go +++ b/internal/command/ssh/vm.go @@ -56,9 +56,7 @@ func runSSHVM(cmd *cobra.Command, args []string) error { wsConn, err := client.VMs().PortForward(cmd.Context(), name, 22, wait) if err != nil { - fmt.Printf("failed to forward an SSH port to VM %s: %v\n", name, err) - - return err + return fmt.Errorf("%w: failed to setup port-forwarding to the VM %q: %v", ErrFailed, name, err) } defer wsConn.Close() @@ -78,9 +76,6 @@ func runSSHVM(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("%w: failed to establish an SSH connection: %v", ErrFailed, err) } - defer func() { - _ = sshConn.Close() - }() sshClient := ssh.NewClient(sshConn, chans, reqs) @@ -188,16 +183,16 @@ func ChooseUsernameAndPassword( // Try to get the credentials from the VM's object stored on controller vm, err := client.VMs().Get(ctx, vmName) - if err != nil { - fmt.Printf("failed to retrieve VM %s's credentials from the API server: %v\n", vmName, err) - } - - if vm.Username != "" && vm.Password != "" { + if err == nil && vm.Username != "" && vm.Password != "" { return vm.Username, vm.Password + } else if err != nil { + fmt.Fprintf(os.Stderr, "failed to retrieve VM %s's credentials from the API server: %v\n", + vmName, err) } // Fall back - _, _ = fmt.Fprintf(os.Stderr, "no credentials specified or found, trying default admin:admin credentials...") + _, _ = fmt.Fprintf(os.Stderr, "no credentials specified or found, "+ + "trying default admin:admin credentials...\n") return "admin", "admin" } diff --git a/internal/command/vnc/vm.go b/internal/command/vnc/vm.go index 019b784..c603fa1 100644 --- a/internal/command/vnc/vm.go +++ b/internal/command/vnc/vm.go @@ -72,6 +72,7 @@ func runVNCVM(cmd *cobra.Command, args []string) (err error) { return } + defer wsConn.Close() if err := proxy.Connections(wsConn, conn); err != nil { fmt.Printf("failed to forward port: %v\n", err) diff --git a/internal/controller/api_vms_portforward.go b/internal/controller/api_vms_portforward.go index 2f3fff9..3d48637 100644 --- a/internal/controller/api_vms_portforward.go +++ b/internal/controller/api_vms_portforward.go @@ -9,8 +9,8 @@ import ( "github.com/cirruslabs/orchard/rpc" "github.com/gin-gonic/gin" "github.com/google/uuid" - "golang.org/x/net/websocket" "net/http" + "nhooyr.io/websocket" "strconv" "time" ) @@ -95,11 +95,26 @@ func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responde // worker will asynchronously start port-forwarding so we wait select { case fromWorkerConnection := <-boomerangConnCh: - websocket.Handler(func(wsConn *websocket.Conn) { - if err := proxy.Connections(wsConn, fromWorkerConnection); err != nil { - controller.logger.Warnf("failed to port-forward: %v", err) - } - }).ServeHTTP(ctx.Writer, ctx.Request) + wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + return responder.Error(err) + } + + expectedMsgType := websocket.MessageBinary + + // Backwards compatibility with older Orchard clients + // using "golang.org/x/net/websocket" package + if ctx.Request.Header.Get("User-Agent") == "" { + expectedMsgType = websocket.MessageText + } + + wsConnAsNetConn := websocket.NetConn(ctx, wsConn, expectedMsgType) + + if err := proxy.Connections(wsConnAsNetConn, fromWorkerConnection); err != nil { + controller.logger.Warnf("failed to port-forward: %v", err) + } return responder.Empty() case <-ctx.Done(): diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 3dc46b1..040c9a5 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -39,7 +39,7 @@ func Connections(left net.Conn, right net.Conn) (finalErr error) { recordErr(<-rightErrCh) } - if strings.Contains(finalErr.Error(), "use of closed network connection") { + if finalErr != nil && strings.Contains(finalErr.Error(), "use of closed network connection") { finalErr = nil } diff --git a/internal/responder/error.go b/internal/responder/error.go index 34fd3e3..54e354b 100644 --- a/internal/responder/error.go +++ b/internal/responder/error.go @@ -20,9 +20,12 @@ func Error(err error) Responder { func (responder *ErrorResponder) Respond(c *gin.Context) { var code = http.StatusInternalServerError + if errors.Is(responder.err, storepkg.ErrNotFound) { code = http.StatusNotFound + } else { + _ = c.Error(responder.err) } - _ = c.Error(responder.err) + c.Status(code) } diff --git a/pkg/client/client.go b/pkg/client/client.go index 60a4ab4..f338940 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -11,14 +11,16 @@ import ( "fmt" "github.com/cirruslabs/orchard/internal/config" "github.com/cirruslabs/orchard/internal/netconstants" + "github.com/cirruslabs/orchard/internal/version" "github.com/cirruslabs/orchard/rpc" - "golang.org/x/net/websocket" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "io" + "net" "net/http" "net/url" + "nhooyr.io/websocket" ) var ( @@ -184,9 +186,7 @@ func (client *Client) request( return fmt.Errorf("%w instantiate a request: %v", ErrFailed, err) } - if client.serviceAccountName != "" && client.serviceAccountToken != "" { - request.SetBasicAuth(client.serviceAccountName, client.serviceAccountToken) - } + client.modifyHeader(request.Header) response, err := client.httpClient.Do(request) if err != nil { @@ -238,10 +238,10 @@ func detailsFromErrorResponseBody(body io.Reader) string { } func (client *Client) wsRequest( - _ context.Context, + ctx context.Context, path string, params map[string]string, -) (*websocket.Conn, error) { +) (net.Conn, error) { endpointURL, err := client.parsePath(path) if err != nil { return nil, err @@ -260,20 +260,27 @@ func (client *Client) wsRequest( } endpointURL.RawQuery = values.Encode() - config, err := websocket.NewConfig(endpointURL.String(), "http://127.0.0.1/") + dialOptions := &websocket.DialOptions{ + HTTPClient: client.httpClient, + HTTPHeader: make(http.Header), + } + + client.modifyHeader(dialOptions.HTTPHeader) + + conn, resp, err := websocket.Dial(ctx, endpointURL.String(), dialOptions) if err != nil { - return nil, fmt.Errorf("%w to create WebSocket configuration: %v", ErrFailed, err) + if resp != nil { + _ = resp.Body.Close() + } + + if resp.StatusCode == http.StatusNotFound { + err = fmt.Errorf("%w (are you sure this VM exists on the controller?)", err) + } + + return nil, err } - if client.serviceAccountName != "" && client.serviceAccountToken != "" { - authPlain := fmt.Sprintf("%s:%s", client.serviceAccountName, client.serviceAccountToken) - authEncoded := base64.StdEncoding.EncodeToString([]byte(authPlain)) - config.Header.Add("Authorization", fmt.Sprintf("Basic %s", authEncoded)) - } - - config.TlsConfig = client.tlsConfig - - return websocket.DialConfig(config) + return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil } func (client *Client) parsePath(path string) (*url.URL, error) { @@ -291,6 +298,16 @@ func (client *Client) parsePath(path string) (*url.URL, error) { }, nil } +func (client *Client) modifyHeader(header http.Header) { + header.Set("User-Agent", fmt.Sprintf("Orchard/%s", version.FullVersion)) + + if client.serviceAccountName != "" && client.serviceAccountToken != "" { + authPlain := fmt.Sprintf("%s:%s", client.serviceAccountName, client.serviceAccountToken) + authEncoded := base64.StdEncoding.EncodeToString([]byte(authPlain)) + header.Set("Authorization", fmt.Sprintf("Basic %s", authEncoded)) + } +} + func (client *Client) Check(ctx context.Context) error { return client.request(ctx, http.MethodGet, "/", nil, nil, nil) } diff --git a/pkg/client/vms.go b/pkg/client/vms.go index 9fefc67..fc3ef74 100644 --- a/pkg/client/vms.go +++ b/pkg/client/vms.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "github.com/cirruslabs/orchard/pkg/resource/v1" - "golang.org/x/net/websocket" + "net" "net/http" "strconv" ) @@ -92,7 +92,7 @@ func (service *VMsService) PortForward( name string, port uint16, waitSeconds uint16, -) (*websocket.Conn, error) { +) (net.Conn, error) { return service.client.wsRequest(ctx, fmt.Sprintf("vms/%s/port-forward", name), map[string]string{ "port": strconv.FormatUint(uint64(port), 10),