Merge bab12e74ed into 88506b1adb
This commit is contained in:
commit
1354d549e1
|
|
@ -435,7 +435,7 @@ paths:
|
|||
minLength: 1
|
||||
required: true
|
||||
- in: query
|
||||
name: stdin
|
||||
name: interactive
|
||||
description: |
|
||||
Whether to allocate an interactive standard input for the command
|
||||
|
||||
|
|
@ -446,6 +446,30 @@ paths:
|
|||
type: boolean
|
||||
default: false
|
||||
required: false
|
||||
- in: query
|
||||
name: stdin
|
||||
deprecated: true
|
||||
description: |
|
||||
Deprecated alias for `interactive`.
|
||||
|
||||
If both `interactive` and `stdin` are provided, their values must match.
|
||||
schema:
|
||||
type: boolean
|
||||
default: false
|
||||
required: false
|
||||
- in: query
|
||||
name: env
|
||||
description: |
|
||||
Environment variables to expose to the command.
|
||||
|
||||
Use deep object query syntax, for example `env[FOO]=bar&env[BAZ]=qux`.
|
||||
style: deepObject
|
||||
explode: true
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
required: false
|
||||
- in: query
|
||||
name: wait
|
||||
description: Duration in seconds for the VM to become available if it's not available already
|
||||
|
|
|
|||
|
|
@ -35,7 +35,15 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
NewErrorResponse("\"command\" parameter cannot be empty"))
|
||||
}
|
||||
|
||||
stdin := ctx.Query("stdin") == "true"
|
||||
interactive, err := parseExecInteractive(ctx)
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusBadRequest, NewErrorResponse("%v", err))
|
||||
}
|
||||
|
||||
command, err = sshexec.CommandWithEnv(command, ctx.QueryMap("env"))
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusBadRequest, NewErrorResponse("%v", err))
|
||||
}
|
||||
|
||||
waitRaw := ctx.DefaultQuery("wait", "10")
|
||||
wait, err := strconv.ParseUint(waitRaw, 10, 16)
|
||||
|
|
@ -68,7 +76,7 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
defer portForwardConn.Close()
|
||||
|
||||
// Establish an SSH connection to a VM
|
||||
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), stdin)
|
||||
exec, err := sshexec.New(portForwardConn, vm.SSHUsername(), vm.SSHPassword(), interactive)
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusServiceUnavailable, NewErrorResponse("failed to establish SSH connection to a VM: %v", err))
|
||||
}
|
||||
|
|
@ -151,6 +159,38 @@ func (controller *Controller) execVM(ctx *gin.Context) responder.Responder {
|
|||
}
|
||||
}
|
||||
|
||||
func parseExecInteractive(ctx *gin.Context) (bool, error) {
|
||||
interactiveRaw, interactivePresent := ctx.GetQuery("interactive")
|
||||
stdinRaw, stdinPresent := ctx.GetQuery("stdin")
|
||||
|
||||
interactive := false
|
||||
if interactivePresent {
|
||||
parsed, err := strconv.ParseBool(interactiveRaw)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("\"interactive\" parameter must be a boolean")
|
||||
}
|
||||
|
||||
interactive = parsed
|
||||
}
|
||||
|
||||
if stdinPresent {
|
||||
stdin, err := strconv.ParseBool(stdinRaw)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("\"stdin\" parameter must be a boolean")
|
||||
}
|
||||
|
||||
if interactivePresent && stdin != interactive {
|
||||
return false, fmt.Errorf("\"interactive\" and \"stdin\" parameters cannot conflict")
|
||||
}
|
||||
|
||||
if !interactivePresent {
|
||||
interactive = stdin
|
||||
}
|
||||
}
|
||||
|
||||
return interactive, nil
|
||||
}
|
||||
|
||||
func (controller *Controller) readFrames(
|
||||
ctx context.Context,
|
||||
wsConn *websocket.Conn,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
//nolint:testpackage // we need to test unexported exec helpers
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseExecInteractive(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
query string
|
||||
interactive bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "default false",
|
||||
interactive: false,
|
||||
},
|
||||
{
|
||||
name: "interactive true",
|
||||
query: "interactive=true",
|
||||
interactive: true,
|
||||
},
|
||||
{
|
||||
name: "stdin alias true",
|
||||
query: "stdin=true",
|
||||
interactive: true,
|
||||
},
|
||||
{
|
||||
name: "matching values accepted",
|
||||
query: "interactive=true&stdin=true",
|
||||
interactive: true,
|
||||
},
|
||||
{
|
||||
name: "conflicting values rejected",
|
||||
query: "interactive=true&stdin=false",
|
||||
errContains: "cannot conflict",
|
||||
},
|
||||
{
|
||||
name: "invalid interactive rejected",
|
||||
query: "interactive=maybe",
|
||||
errContains: "interactive",
|
||||
},
|
||||
{
|
||||
name: "invalid stdin rejected",
|
||||
query: "stdin=maybe",
|
||||
errContains: "stdin",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
interactive, err := parseExecInteractive(execQueryContext(test.query))
|
||||
if test.errContains != "" {
|
||||
require.ErrorContains(t, err, test.errContains)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.interactive, interactive)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func execQueryContext(query string) *gin.Context {
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
|
@ -6,13 +6,18 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"regexp"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/execstream"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var envNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
|
||||
|
||||
type Exec struct {
|
||||
sshClient *ssh.Client
|
||||
sshSession *ssh.Session
|
||||
|
|
@ -87,6 +92,43 @@ func (exec *Exec) Stdin() io.WriteCloser {
|
|||
return exec.stdin
|
||||
}
|
||||
|
||||
func CommandWithEnv(command string, env map[string]string) (string, error) {
|
||||
if len(env) == 0 {
|
||||
return command, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(env))
|
||||
for key, value := range env {
|
||||
if !envNamePattern.MatchString(key) {
|
||||
return "", fmt.Errorf("invalid environment variable name %q", key)
|
||||
}
|
||||
|
||||
if strings.ContainsRune(value, '\x00') {
|
||||
return "", fmt.Errorf("environment variable %q contains NUL byte", key)
|
||||
}
|
||||
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
|
||||
var builder strings.Builder
|
||||
for _, key := range keys {
|
||||
builder.WriteString("export ")
|
||||
builder.WriteString(key)
|
||||
builder.WriteByte('=')
|
||||
builder.WriteString(shellQuote(env[key]))
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(command)
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func shellQuote(value string) string {
|
||||
return "'" + strings.ReplaceAll(value, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
func (exec *Exec) Run(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package sshexec_test
|
|||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -24,3 +25,39 @@ func TestContextCancellationViaNetConnClose(t *testing.T) {
|
|||
_, err := sshexec.New(clientConn, "doesn't", "matter", false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCommandWithEnvNoEnvLeavesCommandUnchanged(t *testing.T) {
|
||||
command, err := sshexec.CommandWithEnv("echo hello", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "echo hello", command)
|
||||
}
|
||||
|
||||
func TestCommandWithEnvSortsAndQuotes(t *testing.T) {
|
||||
command, err := sshexec.CommandWithEnv("printf '%s|%s|%s' \"$GREETING\" \"$NAME\" \"$MULTILINE\"", map[string]string{
|
||||
"NAME": "O'Reilly",
|
||||
"GREETING": "hello $USER",
|
||||
"MULTILINE": "line 1\nline 2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, strings.Join([]string{
|
||||
"export GREETING='hello $USER'",
|
||||
"export MULTILINE='line 1",
|
||||
"line 2'",
|
||||
"export NAME='O'\\''Reilly'",
|
||||
"printf '%s|%s|%s' \"$GREETING\" \"$NAME\" \"$MULTILINE\"",
|
||||
}, "\n"), command)
|
||||
}
|
||||
|
||||
func TestCommandWithEnvRejectsInvalidName(t *testing.T) {
|
||||
_, err := sshexec.CommandWithEnv("echo hello", map[string]string{
|
||||
"1INVALID": "value",
|
||||
})
|
||||
require.ErrorContains(t, err, "invalid environment variable name")
|
||||
}
|
||||
|
||||
func TestCommandWithEnvRejectsNULValue(t *testing.T) {
|
||||
_, err := sshexec.CommandWithEnv("echo hello", map[string]string{
|
||||
"VALID": "bad\x00value",
|
||||
})
|
||||
require.ErrorContains(t, err, "contains NUL byte")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,6 +79,48 @@ func TestVMExecWithStdin(t *testing.T) {
|
|||
require.Equal(t, websocket.StatusNormalClosure, closeError.Code)
|
||||
}
|
||||
|
||||
func TestVMExecWithEnv(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
|
||||
script := "sh -c 'printf \"%s|%s|%s\" \"$GREETING\" \"$QUOTE\" \"$EMPTY\"'"
|
||||
|
||||
wsConn, err := devClient.VMs().ExecWithOptions(t.Context(), vmName, script, client.ExecOptions{
|
||||
Env: map[string]string{
|
||||
"EMPTY": "",
|
||||
"GREETING": "Hello, World!",
|
||||
"QUOTE": "O'Reilly",
|
||||
},
|
||||
WaitSeconds: 30,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer wsConn.CloseNow()
|
||||
|
||||
var stdout bytes.Buffer
|
||||
var exitFrame *execstream.Frame
|
||||
|
||||
for exitFrame == nil {
|
||||
frame := readFrame(t, wsConn)
|
||||
|
||||
switch frame.Type {
|
||||
case execstream.FrameTypeStdout:
|
||||
stdout.Write(frame.Data)
|
||||
case execstream.FrameTypeExit:
|
||||
exitFrame = frame
|
||||
default:
|
||||
t.Fatalf("unexpected frame type %q", frame.Type)
|
||||
}
|
||||
}
|
||||
|
||||
require.EqualValues(t, 0, exitFrame.Exit.Code)
|
||||
require.Equal(t, "Hello, World!|O'Reilly|", stdout.String())
|
||||
|
||||
// Ensure that Orchard Controller gracefully terminates the WebSocket connection
|
||||
_, _, err = wsConn.Read(t.Context())
|
||||
var closeError websocket.CloseError
|
||||
require.ErrorAs(t, err, &closeError)
|
||||
require.Equal(t, websocket.StatusNormalClosure, closeError.Code)
|
||||
}
|
||||
|
||||
func TestVMExecScript(t *testing.T) {
|
||||
devClient, vmName := prepareForExec(t)
|
||||
|
||||
|
|
|
|||
|
|
@ -157,6 +157,12 @@ func (service *VMsService) PortForward(
|
|||
})
|
||||
}
|
||||
|
||||
type ExecOptions struct {
|
||||
Interactive bool
|
||||
WaitSeconds uint16
|
||||
Env map[string]string
|
||||
}
|
||||
|
||||
func (service *VMsService) Exec(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
|
|
@ -164,12 +170,30 @@ func (service *VMsService) Exec(
|
|||
stdin bool,
|
||||
waitSeconds uint16,
|
||||
) (*websocket.Conn, error) {
|
||||
return service.ExecWithOptions(ctx, name, command, ExecOptions{
|
||||
Interactive: stdin,
|
||||
WaitSeconds: waitSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
func (service *VMsService) ExecWithOptions(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
command string,
|
||||
options ExecOptions,
|
||||
) (*websocket.Conn, error) {
|
||||
params := map[string]string{
|
||||
"command": command,
|
||||
"interactive": strconv.FormatBool(options.Interactive),
|
||||
"wait": strconv.FormatUint(uint64(options.WaitSeconds), 10),
|
||||
}
|
||||
|
||||
for key, value := range options.Env {
|
||||
params[fmt.Sprintf("env[%s]", key)] = value
|
||||
}
|
||||
|
||||
return service.client.wsRequestRaw(ctx, fmt.Sprintf("vms/%s/exec", url.PathEscape(name)),
|
||||
map[string]string{
|
||||
"command": command,
|
||||
"stdin": strconv.FormatBool(stdin),
|
||||
"wait": strconv.FormatUint(uint64(waitSeconds), 10),
|
||||
})
|
||||
params)
|
||||
}
|
||||
|
||||
func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint16) (string, error) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue