orchard/internal/controller/sshexec/sshexec.go

374 lines
7.5 KiB
Go

package sshexec
import (
"context"
"errors"
"fmt"
"io"
"net"
"regexp"
"slices"
"sort"
"strings"
"sync"
"github.com/cirruslabs/orchard/internal/execstream"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// envNamePattern matches names that start with an ASCII letter or underscore
// and continue with ASCII letters, digits or underscores. CommandWithOptions
// uses it to validate environment variable names before emitting them
// unquoted into a shell script.
var envNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
// Options describes how a command should be executed over SSH.
type Options struct {
// Interactive makes NewExec attach a standard input pipe to the session.
Interactive bool
// TTY makes NewExec request a pseudo-terminal for the session; it also
// causes a standard input pipe to be attached, just like Interactive.
TTY bool
// Rows and Cols are passed to the PTY request when TTY is set.
Rows uint32
Cols uint32
// Env holds environment variables that CommandWithOptions exports before
// the command. Names must match envNamePattern and values must not
// contain NUL bytes.
Env map[string]string
// Workdir, when non-empty, is the directory CommandWithOptions changes
// into before the command. It must not contain NUL bytes.
Workdir string
}
// Client wraps an SSH client connection and tracks its termination.
type Client struct {
// sshClient is the underlying SSH client created by NewClient.
sshClient *ssh.Client
// done is closed by the goroutine started in NewClient once
// sshClient.Wait has returned.
done chan struct{}
// waitErr is the value returned by sshClient.Wait; it is written before
// done is closed and read by Err.
waitErr error
// waitMu guards waitErr.
waitMu sync.Mutex
}
// Exec is a single SSH session prepared for running one command and
// streaming its output.
type Exec struct {
// ownedClient is set by New; when non-nil, Close also closes this client.
ownedClient *Client
// sshSession is the SSH session the command runs in.
sshSession *ssh.Session
// stdout and stderr are the session's output pipes obtained in NewExec.
stdout io.Reader
stderr io.Reader
// stdin is the write end of the pipe feeding the session's standard
// input. It is nil unless Options.Interactive or Options.TTY was set.
stdin io.WriteCloser
// stdinReader is the read end of the same pipe; it is closed when Run
// returns and by Close. It is nil whenever stdin is nil.
stdinReader *io.PipeReader
// tty records whether a PTY was requested; Resize refuses to work without it.
tty bool
}
// NewClient performs an SSH client handshake over netConn, authenticating as
// user with the given password, and returns a Client wrapping the resulting
// connection.
//
// The server's host key is accepted without any verification.
//
// A background goroutine waits for the SSH connection to terminate, stores
// the outcome for Err and then closes the channel returned by Done.
func NewClient(netConn net.Conn, user string, password string) (*Client, error) {
	config := &ssh.ClientConfig{
		User:            user,
		Auth:            []ssh.AuthMethod{ssh.Password(password)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	conn, newChannels, requests, err := ssh.NewClientConn(netConn, "", config)
	if err != nil {
		return nil, fmt.Errorf("failed to create an SSH connection: %w", err)
	}

	client := &Client{
		sshClient: ssh.NewClient(conn, newChannels, requests),
		done:      make(chan struct{}),
	}

	go func() {
		waitErr := client.sshClient.Wait()

		// Publish the result first so that anyone unblocked by done sees it.
		client.waitMu.Lock()
		client.waitErr = waitErr
		client.waitMu.Unlock()

		close(client.done)
	}()

	return client, nil
}
// New creates a dedicated Client over netConn and prepares a single Exec on
// it using options. The returned Exec owns that client: closing the Exec
// closes the client as well.
//
// If the Exec cannot be prepared, the freshly created client is closed and
// the preparation error is returned.
func New(netConn net.Conn, user string, password string, options Options) (*Exec, error) {
	owner, err := NewClient(netConn, user, password)
	if err != nil {
		return nil, err
	}

	session, err := owner.NewExec(options)
	if err != nil {
		// Best-effort cleanup; the NewExec failure is the error worth reporting.
		_ = owner.Close()

		return nil, err
	}

	session.ownedClient = owner

	return session, nil
}
// NewExec opens a new SSH session on the client and prepares it according to
// options:
//
//   - when Interactive or TTY is set, a pipe is attached to the session's
//     standard input and its write end is exposed through Exec.Stdin;
//   - when TTY is set, an "xterm-256color" PTY of Rows x Cols is requested;
//   - standard output and standard error pipes are always obtained.
//
// The session is closed again if any step after its creation fails.
func (client *Client) NewExec(options Options) (*Exec, error) {
	session, err := client.sshClient.NewSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create an SSH session: %w", err)
	}

	// abort releases the half-configured session before reporting cause.
	abort := func(cause error) (*Exec, error) {
		_ = session.Close()

		return nil, cause
	}

	result := &Exec{
		sshSession: session,
		tty:        options.TTY,
	}

	if wantsStdin := options.Interactive || options.TTY; wantsStdin {
		pipeReader, pipeWriter := io.Pipe()

		session.Stdin = pipeReader
		result.stdinReader = pipeReader
		result.stdin = pipeWriter
	}

	if options.TTY {
		height, width := int(options.Rows), int(options.Cols)

		err := session.RequestPty("xterm-256color", height, width, ssh.TerminalModes{})
		if err != nil {
			return abort(fmt.Errorf("failed to request PTY for the SSH session: %w", err))
		}
	}

	stdout, err := session.StdoutPipe()
	if err != nil {
		return abort(fmt.Errorf("failed to create standard output pipe for the SSH session: %w", err))
	}
	result.stdout = stdout

	stderr, err := session.StderrPipe()
	if err != nil {
		return abort(fmt.Errorf("failed to create standard error pipe for the SSH session: %w", err))
	}
	result.stderr = stderr

	return result, nil
}
// Keepalive sends a "keepalive@openssh.com" global request over the SSH
// connection, asking for a reply, and returns any error from doing so. The
// reply's content is ignored.
func (client *Client) Keepalive() error {
	const requestName = "keepalive@openssh.com"

	_, _, err := client.sshClient.SendRequest(requestName, true, nil)

	return err
}
// Done returns a channel that is closed once the underlying SSH client's
// Wait call has returned (see the goroutine started in NewClient).
func (client *Client) Done() <-chan struct{} {
return client.done
}
// Err blocks until the channel returned by Done is closed and then returns
// the error that the underlying SSH client's Wait call produced.
func (client *Client) Err() error {
	<-client.done

	client.waitMu.Lock()
	result := client.waitErr
	client.waitMu.Unlock()

	return result
}
// Close closes the underlying SSH client and returns the error, if any,
// reported by it.
func (client *Client) Close() error {
return client.sshClient.Close()
}
// ShouldRecreateAfter reports whether the client should be discarded and
// re-created after an operation on it returned err.
//
// It returns true when the SSH connection has already terminated (the Done
// channel is closed), or when err indicates a closed connection: io.EOF,
// net.ErrClosed, or a message containing "use of closed network connection".
//
// A nil err on a connection that is still alive yields false.
func (client *Client) ShouldRecreateAfter(err error) bool {
	select {
	case <-client.done:
		return true
	default:
	}

	// Guard against a nil error: calling err.Error() on it below would panic,
	// and a nil error carries no evidence of a broken connection anyway.
	if err == nil {
		return false
	}

	return errors.Is(err, io.EOF) ||
		errors.Is(err, net.ErrClosed) ||
		strings.Contains(err.Error(), "use of closed network connection")
}
// Stdin returns the write end of the pipe connected to the session's
// standard input. It is nil when the Exec was created with neither
// Options.Interactive nor Options.TTY set.
func (exec *Exec) Stdin() io.WriteCloser {
return exec.stdin
}
// CommandWithOptions wraps command in a small shell script that applies the
// working directory and environment variables from options.
//
// The resulting script consists of newline-separated lines:
//
//   - "cd <workdir> || exit $?" when options.Workdir is non-empty;
//   - one "export NAME=<value>" line per environment variable, sorted by name;
//   - the command itself.
//
// The working directory and the values are single-quoted for the shell. The
// command is returned unchanged when it is empty or when there is neither a
// working directory nor any environment variable to apply.
//
// An error is returned if the working directory or an environment variable
// value contains a NUL byte, or if an environment variable name does not
// match envNamePattern. Validation happens even for an empty command.
func CommandWithOptions(command string, options Options) (string, error) {
	workdir, env := options.Workdir, options.Env

	if strings.ContainsRune(workdir, '\x00') {
		return "", errors.New("working directory contains NUL byte")
	}

	names := make([]string, 0, len(env))
	for name, value := range env {
		switch {
		case !envNamePattern.MatchString(name):
			return "", fmt.Errorf("invalid environment variable name %q", name)
		case strings.ContainsRune(value, '\x00'):
			return "", fmt.Errorf("environment variable %q contains NUL byte", name)
		}

		names = append(names, name)
	}

	// Map iteration order is random; sort for a deterministic script.
	sort.Strings(names)

	if command == "" || (workdir == "" && len(names) == 0) {
		return command, nil
	}

	// At most one "cd" line, one line per variable and the command itself.
	lines := make([]string, 0, len(names)+2)

	if workdir != "" {
		lines = append(lines, "cd "+shellQuote(workdir)+" || exit $?")
	}

	for _, name := range names {
		lines = append(lines, "export "+name+"="+shellQuote(env[name]))
	}

	lines = append(lines, command)

	return strings.Join(lines, "\n"), nil
}
// shellQuote wraps value in single quotes so that a POSIX shell treats it as
// one literal word. Every single quote inside value is emitted as '\'' :
// close the quoted string, add an escaped quote, reopen the quoted string.
func shellQuote(value string) string {
	const escapedQuote = `'\''`

	var quoted strings.Builder
	quoted.Grow(len(value) + 2)

	quoted.WriteByte('\'')

	// Working on bytes is safe here: the single quote is ASCII, so it never
	// occurs inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(value); i++ {
		if value[i] == '\'' {
			quoted.WriteString(escapedQuote)

			continue
		}

		quoted.WriteByte(value[i])
	}

	quoted.WriteByte('\'')

	return quoted.String()
}
// Resize notifies the remote side that the terminal now has the given number
// of rows and columns. It returns an error when the Exec was created without
// a TTY, or when sending the window change fails.
func (exec *Exec) Resize(rows uint32, cols uint32) error {
	if exec.tty {
		height, width := int(rows), int(cols)

		return exec.sshSession.WindowChange(height, width)
	}

	return errors.New("this exec session does not have a TTY")
}
// Run starts command in the SSH session and streams its standard output and
// standard error to outgoingFrames as stdout/stderr frames. Once the command
// has finished and both streams are drained, a final exit frame carrying the
// command's exit code is sent.
//
// Run returns:
//   - an error if the command cannot be started;
//   - ctx.Err() if ctx is done before the command finishes or before the exit
//     frame can be delivered (in the former case Run does not wait for the
//     stream readers);
//   - the first error reported by the stream readers;
//   - a wrapped error if waiting for the command fails with anything other
//     than an *ssh.ExitError.
//
// The read end of the standard input pipe, if any, is closed when Run returns.
func (exec *Exec) Run(
	ctx context.Context,
	command string,
	outgoingFrames chan<- *execstream.Frame,
) error {
	if exec.stdinReader != nil {
		defer func() {
			_ = exec.stdinReader.Close()
		}()
	}

	if startErr := exec.sshSession.Start(command); startErr != nil {
		return fmt.Errorf("failed to start command %q: %w", command, startErr)
	}

	// Forward both output streams concurrently, one frame per chunk read.
	streams, streamsCtx := errgroup.WithContext(ctx)

	forward := func(source io.Reader, frameType execstream.FrameType) {
		streams.Go(func() error {
			return ioStreamReader(streamsCtx, source, frameType, outgoingFrames)
		})
	}

	forward(exec.stdout, execstream.FrameTypeStdout)
	forward(exec.stderr, execstream.FrameTypeStderr)

	// The channel is buffered so that the goroutine can always deliver its
	// result and exit, even if Run has already returned due to ctx.
	waitDone := make(chan error, 1)

	go func() {
		waitDone <- exec.sshSession.Wait()
	}()

	var waitErr error

	select {
	case <-ctx.Done():
		return ctx.Err()
	case waitErr = <-waitDone:
		// The command has terminated, carry on.
	}

	// Drain the streams before reporting the exit, otherwise the tail of the
	// command's output could be lost.
	if streamErr := streams.Wait(); streamErr != nil {
		return streamErr
	}

	var exitCode int32

	if waitErr != nil {
		var exitErr *ssh.ExitError

		// Only a proper non-zero exit is reported through the exit frame;
		// anything else is a failure of the execution itself.
		if !errors.As(waitErr, &exitErr) {
			return fmt.Errorf("failed to execute command %q: %w", command, waitErr)
		}

		exitCode = int32(exitErr.ExitStatus())
	}

	exitFrame := &execstream.Frame{
		Type: execstream.FrameTypeExit,
		Exit: &execstream.Exit{
			Code: exitCode,
		},
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case outgoingFrames <- exitFrame:
		return nil
	}
}
// ioStreamReader reads r in chunks of up to 4096 bytes and sends every
// non-empty chunk to ch as a frame of the given frameType.
//
// It returns nil once r reports io.EOF, the read error for any other failure,
// and ctx.Err() if ctx is done while a frame is waiting to be sent.
func ioStreamReader(
	ctx context.Context,
	r io.Reader,
	frameType execstream.FrameType,
	ch chan<- *execstream.Frame,
) error {
	chunk := make([]byte, 4096)

	for {
		read, readErr := r.Read(chunk)

		// A reader may return data together with an error, so the data is
		// always forwarded before the error is looked at.
		if read > 0 {
			// chunk is reused by the next Read, hence the copy.
			payload := slices.Clone(chunk[:read])

			select {
			case ch <- &execstream.Frame{Type: frameType, Data: payload}:
				// Delivered, keep reading.
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		switch {
		case readErr == nil:
			continue
		case errors.Is(readErr, io.EOF):
			return nil
		default:
			return readErr
		}
	}
}
// Close releases everything held by the Exec: both ends of the standard
// input pipe (if any), the SSH session and, when the Exec was created by New,
// the client it owns.
//
// Errors from closing the pipe ends are ignored. The session's close error
// takes precedence over the owned client's close error; the client is closed
// in either case.
func (exec *Exec) Close() error {
	if exec.stdin != nil {
		_ = exec.stdin.Close()
	}

	if exec.stdinReader != nil {
		_ = exec.stdinReader.Close()
	}

	sessionErr := exec.sshSession.Close()

	owner := exec.ownedClient
	if owner == nil {
		return sessionErr
	}

	clientErr := owner.Close()

	if sessionErr != nil {
		return sessionErr
	}

	return clientErr
}