orchard/internal/command/ssh/vm.go

205 lines
4.6 KiB
Go

package ssh
import (
"context"
"errors"
"fmt"
"github.com/cirruslabs/orchard/pkg/client"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
"io"
"net"
"os"
"time"
)
// ErrFailed is the sentinel wrapped (via %w) into the errors reported when
// setting up the port-forward, SSH connection, SSH session or local terminal
// fails, so callers can detect such failures with errors.Is.
var ErrFailed = errors.New("ssh command failed")

// Flag storage for the "vm" subcommand; cobra fills these in when parsing
// the flags registered in newSSHVMCommand.
var (
	// username holds the value of --username/-u ("" when not given).
	username string
	// password holds the value of --password/-p ("" when not given).
	password string
	// wait holds the value of --wait/-t: the number of seconds to wait for
	// the VM to start running if it's not running already.
	wait uint16
)
// newSSHVMCommand builds the "vm VM_NAME [COMMAND]" subcommand, which is
// executed by runSSHVM and accepts one or two positional arguments.
//
// The --username/-u, --password/-p and --wait/-t flags are bound to the
// package-level username, password and wait variables (wait defaults to 60).
func newSSHVMCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "vm VM_NAME [COMMAND]",
		Short: "SSH into the VM",
		Args:  cobra.RangeArgs(1, 2),
		RunE:  runSSHVM,
	}

	flags := cmd.Flags()
	flags.StringVarP(&username, "username", "u", "", "SSH username")
	flags.StringVarP(&password, "password", "p", "", "SSH password")
	flags.Uint16VarP(&wait, "wait", "t", 60,
		"Amount of seconds to wait for the VM to start running if it's not running already")

	return cmd
}
// runSSHVM implements "vm VM_NAME [COMMAND]".
//
// It asks the controller to port-forward to port 22 of VM_NAME, establishes
// an SSH connection over that tunnel and opens a session. When COMMAND is
// given, it is run with the local stdin/stdout/stderr attached and its result
// is returned. Otherwise an interactive shell is started on a remote PTY (see
// runInteractiveSSHShell), which requires the local stdin to be a terminal.
//
// Errors from the port-forward, SSH handshake, session and terminal setup are
// wrapped with ErrFailed.
func runSSHVM(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	// Required NAME argument
	name := args[0]

	// Optional [COMMAND] argument
	var command string
	if len(args) > 1 {
		command = args[1]
	}

	// Named vmClient so that it does not shadow the imported "client" package.
	vmClient, err := client.New()
	if err != nil {
		return err
	}

	wsConn, err := vmClient.VMs().PortForward(ctx, name, 22, wait)
	if err != nil {
		return fmt.Errorf("%w: failed to setup port-forwarding to the VM %q: %v", ErrFailed, name, err)
	}
	defer wsConn.Close()

	// Keep the resolved credentials local instead of overwriting the
	// package-level flag variables.
	sshUsername, sshPassword := ChooseUsernameAndPassword(ctx, vmClient, name, username, password)

	sshConfig := &ssh.ClientConfig{
		// NOTE(review): the VM's host key is not verified (same as the
		// original accept-everything callback); presumably acceptable because
		// the connection is tunneled through the controller — confirm.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		User:            sshUsername,
		Auth: []ssh.AuthMethod{
			ssh.Password(sshPassword),
		},
	}

	sshConn, chans, reqs, err := ssh.NewClientConn(wsConn, "", sshConfig)
	if err != nil {
		return fmt.Errorf("%w: failed to establish an SSH connection: %v", ErrFailed, err)
	}

	sshClient := ssh.NewClient(sshConn, chans, reqs)
	defer func() {
		_ = sshClient.Close()
	}()

	sshSess, err := sshClient.NewSession()
	if err != nil {
		return fmt.Errorf("%w: failed to open an SSH session: %v", ErrFailed, err)
	}
	defer func() {
		_ = sshSess.Close()
	}()

	if command != "" {
		sshSess.Stdout = os.Stdout
		sshSess.Stderr = os.Stderr
		sshSess.Stdin = os.Stdin

		return sshSess.Run(command)
	}

	return runInteractiveSSHShell(ctx, sshSess)
}

// runInteractiveSSHShell puts the local terminal into raw mode, requests an
// "xterm-256color" PTY sized like the local terminal, wires local stdio to
// sshSess and starts a shell. It returns the shell's Wait result, or
// ctx.Err() if ctx is done first. The terminal state is restored on return.
func runInteractiveSSHShell(ctx context.Context, sshSess *ssh.Session) error {
	// Switch controlling terminal into raw mode,
	// otherwise ANSI escape sequences that allow
	// for cursor control and more wouldn't work
	stdinFD := int(os.Stdin.Fd())
	state, err := term.MakeRaw(stdinFD)
	if err != nil {
		return fmt.Errorf("%w: failed to switch controlling terminal into raw mode: %v", ErrFailed, err)
	}
	defer func() {
		_ = term.Restore(stdinFD, state)
	}()

	stdoutFD := int(os.Stdout.Fd())
	width, height, err := term.GetSize(stdoutFD)
	if err != nil {
		return fmt.Errorf("%w: failed to determine the terminal size: %v", ErrFailed, err)
	}

	if err := sshSess.RequestPty("xterm-256color", height, width, ssh.TerminalModes{}); err != nil {
		return fmt.Errorf("%w: failed to request the PTY from the SSH server: %v", ErrFailed, err)
	}

	sshSess.Stdout = os.Stdout
	sshSess.Stderr = os.Stderr

	stdinPipe, err := sshSess.StdinPipe()
	if err != nil {
		return err
	}

	// Reads from os.Stdin block and cannot be interrupted, so this goroutine
	// only ends when stdin returns EOF/an error (or the process exits).
	go func() {
		_, _ = io.Copy(stdinPipe, os.Stdin)
		_ = stdinPipe.Close()
	}()

	// Tie the resize loop to this function's lifetime, not only to the
	// command's context, so it stops as soon as the shell session ends.
	resizeCtx, cancelResize := context.WithCancel(ctx)
	defer cancelResize()

	// Poll at a fixed interval; this is unrelated to the --wait flag.
	go syncRemoteTerminalSize(resizeCtx, sshSess, stdoutFD, width, height, time.Second)

	if err := sshSess.Shell(); err != nil {
		return err
	}

	// Buffered so the Wait goroutine can always deliver its result and exit,
	// even when we return early because ctx was cancelled.
	sshErr := make(chan error, 1)
	go func() {
		sshErr <- sshSess.Wait()
	}()

	select {
	case err := <-sshErr:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

// syncRemoteTerminalSize checks the size of the local terminal fd every
// interval and sends a window-change request to sshSess whenever it differs
// from the last size successfully applied (initially width x height). It
// returns when ctx is done or the local size can no longer be read.
func syncRemoteTerminalSize(ctx context.Context, sshSess *ssh.Session, fd, width, height int, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// Proceed with adjusting the remote terminal size
		case <-ctx.Done():
			return
		}

		newWidth, newHeight, err := term.GetSize(fd)
		if err != nil {
			return
		}
		if newWidth == width && newHeight == height {
			continue
		}

		// On failure keep the previous size so the change is retried on the
		// next tick.
		if err := sshSess.WindowChange(newHeight, newWidth); err != nil {
			continue
		}
		width, height = newWidth, newHeight
	}
}
// ChooseUsernameAndPassword picks the SSH credentials for vmName.
//
// The rules are applied in order:
//  1. If the user supplied a username or a password, both values are
//     returned exactly as given, even if one of them is empty.
//  2. Otherwise the VM is fetched from the controller, and its Username and
//     Password are returned when both are non-empty.
//  3. Otherwise "admin"/"admin" is returned. A notice is printed to stderr,
//     preceded by the lookup error if fetching the VM failed.
func ChooseUsernameAndPassword(
	ctx context.Context,
	apiClient *client.Client,
	vmName string,
	usernameFromUser string,
	passwordFromUser string,
) (string, string) {
	userSupplied := usernameFromUser != "" || passwordFromUser != ""
	if userSupplied {
		return usernameFromUser, passwordFromUser
	}

	vm, err := apiClient.VMs().Get(ctx, vmName)
	switch {
	case err != nil:
		_, _ = fmt.Fprintf(os.Stderr, "failed to retrieve VM %s's credentials from the API server: %v\n",
			vmName, err)
	case vm.Username != "" && vm.Password != "":
		return vm.Username, vm.Password
	}

	// Neither the user nor the controller provided usable credentials.
	_, _ = fmt.Fprintf(os.Stderr, "no credentials specified or found, "+
		"trying default admin:admin credentials...\n")

	return "admin", "admin"
}