diff --git a/api/openapi.yaml b/api/openapi.yaml index 29823a2..08b7026 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -435,7 +435,7 @@ paths: minLength: 1 required: true - in: query - name: stdin + name: interactive description: | Whether to allocate an interactive standard input for the command @@ -446,6 +446,30 @@ paths: type: boolean default: false required: false + - in: query + name: stdin + deprecated: true + description: | + Deprecated alias for `interactive`. + + If both `interactive` and `stdin` are provided, their values must match. + schema: + type: boolean + default: false + required: false + - in: query + name: env + description: | + Environment variables to expose to the command. + + Use deep object query syntax, for example `env[FOO]=bar&env[BAZ]=qux`. + style: deepObject + explode: true + schema: + type: object + additionalProperties: + type: string + required: false - in: query name: wait description: Duration in seconds for the VM to become available if it's not available already diff --git a/internal/controller/api_vms_exec.go b/internal/controller/api_vms_exec.go index de384ca..2dcd87d 100644 --- a/internal/controller/api_vms_exec.go +++ b/internal/controller/api_vms_exec.go @@ -35,7 +35,15 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { NewErrorResponse("\"command\" parameter cannot be empty")) } - stdin := ctx.Query("stdin") == "true" + interactive, err := parseExecInteractive(ctx) + if err != nil { + return responder.JSON(http.StatusBadRequest, NewErrorResponse("%v", err)) + } + + command, err = sshexec.CommandWithEnv(command, ctx.QueryMap("env")) + if err != nil { + return responder.JSON(http.StatusBadRequest, NewErrorResponse("%v", err)) + } waitRaw := ctx.DefaultQuery("wait", "10") wait, err := strconv.ParseUint(waitRaw, 10, 16) @@ -68,7 +76,7 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { defer portForwardConn.Close() // Establish an SSH connection to a VM - exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin) + exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), interactive) if err != nil { return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err)) } @@ -151,6 +159,38 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { } } +func parseExecInteractive(ctx *gin.Context) (bool, error) { + interactiveRaw, interactivePresent := ctx.GetQuery("interactive") + stdinRaw, stdinPresent := ctx.GetQuery("stdin") + + interactive := false + if interactivePresent { + parsed, err := strconv.ParseBool(interactiveRaw) + if err != nil { + return false, fmt.Errorf("\"interactive\" parameter must be a boolean") + } + + interactive = parsed + } + + if stdinPresent { + stdin, err := strconv.ParseBool(stdinRaw) + if err != nil { + return false, fmt.Errorf("\"stdin\" parameter must be a boolean") + } + + if interactivePresent && stdin != interactive { + return false, fmt.Errorf("\"interactive\" and \"stdin\" parameters cannot conflict") + } + + if !interactivePresent { + interactive = stdin + } + } + + return interactive, nil +} + func (controller *Controller) readFrames( ctx context.Context, wsConn *websocket.Conn, diff --git a/internal/controller/api_vms_exec_test.go b/internal/controller/api_vms_exec_test.go new file mode 100644 index 0000000..e11ad12 --- /dev/null +++ b/internal/controller/api_vms_exec_test.go @@ -0,0 +1,74 @@ +//nolint:testpackage // we need to test unexported exec helpers +package controller + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestParseExecInteractive(t *testing.T) { + for _, test := range []struct { + name string + query string + interactive bool + errContains string + }{ + { + name: "default false", + interactive: false, + }, + { + name: "interactive true", + query: "interactive=true", + interactive: true, + }, + { + name: "stdin alias true", + query: "stdin=true", + interactive: true, + }, + { + name: "matching values accepted", + query: "interactive=true&stdin=true", + interactive: true, + }, + { + name: "conflicting values rejected", + query: "interactive=true&stdin=false", + errContains: "cannot conflict", + }, + { + name: "invalid interactive rejected", + query: "interactive=maybe", + errContains: "interactive", + }, + { + name: "invalid stdin rejected", + query: "stdin=maybe", + errContains: "stdin", + }, + } { + t.Run(test.name, func(t *testing.T) { + interactive, err := parseExecInteractive(execQueryContext(test.query)) + if test.errContains != "" { + require.ErrorContains(t, err, test.errContains) + + return + } + + require.NoError(t, err) + require.Equal(t, test.interactive, interactive) + }) + } +} + +func execQueryContext(query string) *gin.Context { + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ctx.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil) + + return ctx +} diff --git a/internal/controller/sshexec/sshexec.go b/internal/controller/sshexec/sshexec.go index 41c8f70..f2156e3 100644 --- a/internal/controller/sshexec/sshexec.go +++ b/internal/controller/sshexec/sshexec.go @@ -6,13 +6,18 @@ import ( "fmt" "io" "net" + "regexp" "slices" + "sort" + "strings" "github.com/cirruslabs/orchard/internal/execstream" "golang.org/x/crypto/ssh" "golang.org/x/sync/errgroup" ) +var envNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + type Exec struct { sshClient *ssh.Client sshSession *ssh.Session @@ -87,6 +92,43 @@ func (exec *Exec) Stdin() io.WriteCloser { return exec.stdin } +func CommandWithEnv(command string, env map[string]string) (string, error) { + if len(env) == 0 { + return command, nil + } + + keys := make([]string, 0, len(env)) + for key, value := range env { + if !envNamePattern.MatchString(key) { + return "", fmt.Errorf("invalid environment variable name %q", key) + } + + if strings.ContainsRune(value, '\x00') { + return "", fmt.Errorf("environment variable %q contains NUL byte", key) + } + + keys = append(keys, key) + } + + sort.Strings(keys) + + var builder strings.Builder + for _, key := range keys { + builder.WriteString("export ") + builder.WriteString(key) + builder.WriteByte('=') + builder.WriteString(shellQuote(env[key])) + builder.WriteByte('\n') + } + builder.WriteString(command) + + return builder.String(), nil +} + +func shellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", "'\\''") + "'" +} + func (exec *Exec) Run( ctx context.Context, command string, diff --git a/internal/controller/sshexec/sshexec_test.go b/internal/controller/sshexec/sshexec_test.go index bd360ea..40ff163 100644 --- a/internal/controller/sshexec/sshexec_test.go +++ b/internal/controller/sshexec/sshexec_test.go @@ -2,6 +2,7 @@ package sshexec_test import ( "net" + "strings" "testing" "time" @@ -24,3 +25,39 @@ func TestContextCancellationViaNetConnClose(t *testing.T) { _, err := sshexec.New(clientConn, "doesn't", "matter", false) require.Error(t, err) } + +func TestCommandWithEnvNoEnvLeavesCommandUnchanged(t *testing.T) { + command, err := sshexec.CommandWithEnv("echo hello", nil) + require.NoError(t, err) + require.Equal(t, "echo hello", command) +} + +func TestCommandWithEnvSortsAndQuotes(t *testing.T) { + command, err := sshexec.CommandWithEnv("printf '%s|%s|%s' \"$GREETING\" \"$NAME\" \"$MULTILINE\"", map[string]string{ + "NAME": "O'Reilly", + "GREETING": "hello $USER", + "MULTILINE": "line 1\nline 2", + }) + require.NoError(t, err) + require.Equal(t, strings.Join([]string{ + "export GREETING='hello $USER'", + "export MULTILINE='line 1", + "line 2'", + "export NAME='O'\\''Reilly'", + "printf '%s|%s|%s' \"$GREETING\" \"$NAME\" \"$MULTILINE\"", + }, "\n"), command) +} + +func TestCommandWithEnvRejectsInvalidName(t *testing.T) { + _, err := sshexec.CommandWithEnv("echo hello", map[string]string{ + "1INVALID": "value", + }) + require.ErrorContains(t, err, "invalid environment variable name") +} + +func TestCommandWithEnvRejectsNULValue(t *testing.T) { + _, err := sshexec.CommandWithEnv("echo hello", map[string]string{ + "VALID": "bad\x00value", + }) + require.ErrorContains(t, err, "contains NUL byte") +} diff --git a/internal/tests/exec_test.go b/internal/tests/exec_test.go index 227f66d..bcf03dc 100644 --- a/internal/tests/exec_test.go +++ b/internal/tests/exec_test.go @@ -79,6 +79,48 @@ func TestVMExecWithStdin(t *testing.T) { require.Equal(t, websocket.StatusNormalClosure, closeError.Code) } +func TestVMExecWithEnv(t *testing.T) { + devClient, vmName := prepareForExec(t) + + script := "sh -c 'printf \"%s|%s|%s\" \"$GREETING\" \"$QUOTE\" \"$EMPTY\"'" + + wsConn, err := devClient.VMs().ExecWithOptions(t.Context(), vmName, script, client.ExecOptions{ + Env: map[string]string{ + "EMPTY": "", + "GREETING": "Hello, World!", + "QUOTE": "O'Reilly", + }, + WaitSeconds: 30, + }) + require.NoError(t, err) + defer wsConn.CloseNow() + + var stdout bytes.Buffer + var exitFrame *execstream.Frame + + for exitFrame == nil { + frame := readFrame(t, wsConn) + + switch frame.Type { + case execstream.FrameTypeStdout: + stdout.Write(frame.Data) + case execstream.FrameTypeExit: + exitFrame = frame + default: + t.Fatalf("unexpected frame type %q", frame.Type) + } + } + + require.EqualValues(t, 0, exitFrame.Exit.Code) + require.Equal(t, "Hello, World!|O'Reilly|", stdout.String()) + + // Ensure that Orchard Controller gracefully terminates the WebSocket connection + _, _, err = wsConn.Read(t.Context()) + var closeError websocket.CloseError + require.ErrorAs(t, err, &closeError) + require.Equal(t, websocket.StatusNormalClosure, closeError.Code) +} + func TestVMExecScript(t *testing.T) { devClient, vmName := prepareForExec(t) diff --git a/pkg/client/vms.go b/pkg/client/vms.go index 3bcbb6e..9bec1d7 100644 --- a/pkg/client/vms.go +++ b/pkg/client/vms.go @@ -157,6 +157,12 @@ func (service *VMsService) PortForward( }) } +type ExecOptions struct { + Interactive bool + WaitSeconds uint16 + Env map[string]string +} + func (service *VMsService) Exec( ctx context.Context, name string, @@ -164,12 +170,30 @@ func (service *VMsService) Exec( stdin bool, waitSeconds uint16, ) (*websocket.Conn, error) { + return service.ExecWithOptions(ctx, name, command, ExecOptions{ + Interactive: stdin, + WaitSeconds: waitSeconds, + }) +} + +func (service *VMsService) ExecWithOptions( + ctx context.Context, + name string, + command string, + options ExecOptions, +) (*websocket.Conn, error) { + params := map[string]string{ + "command": command, + "interactive": strconv.FormatBool(options.Interactive), + "wait": strconv.FormatUint(uint64(options.WaitSeconds), 10), + } + + for key, value := range options.Env { + params[fmt.Sprintf("env[%s]", key)] = value + } + return service.client.wsRequestRaw(ctx, fmt.Sprintf("vms/%s/exec", url.PathEscape(name)), - map[string]string{ - "command": command, - "stdin": strconv.FormatBool(stdin), - "wait": strconv.FormatUint(uint64(waitSeconds), 10), - }) + params) } func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint16) (string, error) {