205 lines
4.6 KiB
Go
205 lines
4.6 KiB
Go
package ssh
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/cirruslabs/orchard/pkg/client"
|
|
"github.com/spf13/cobra"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/term"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
)
|
|
|
|
// ErrFailed is the sentinel that runSSHVM wraps around SSH-related failures
// (port-forwarding, connection and session setup, terminal handling), so that
// callers can recognize them with errors.Is.
var ErrFailed = errors.New("ssh command failed")

// Values of the "vm" subcommand's flags; cobra fills them in when parsing
// the command line (see newSSHVMCommand).
var (
	username string // --username/-u
	password string // --password/-p
	wait     uint16 // --wait/-t, in seconds
)
|
|
|
|
func newSSHVMCommand() *cobra.Command {
|
|
command := &cobra.Command{
|
|
Use: "vm VM_NAME [COMMAND]",
|
|
Short: "SSH into the VM",
|
|
Args: cobra.RangeArgs(1, 2),
|
|
RunE: runSSHVM,
|
|
}
|
|
|
|
command.Flags().StringVarP(&username, "username", "u", "",
|
|
"SSH username")
|
|
command.Flags().StringVarP(&password, "password", "p", "",
|
|
"SSH password")
|
|
command.Flags().Uint16VarP(&wait, "wait", "t", 60,
|
|
"Amount of seconds to wait for the VM to start running if it's not running already")
|
|
|
|
return command
|
|
}
|
|
|
|
func runSSHVM(cmd *cobra.Command, args []string) error {
|
|
// Required NAME argument
|
|
name := args[0]
|
|
|
|
// Optional [COMMAND] argument
|
|
var command string
|
|
|
|
if len(args) > 1 {
|
|
command = args[1]
|
|
}
|
|
|
|
client, err := client.New()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
wsConn, err := client.VMs().PortForward(cmd.Context(), name, 22, wait)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: failed to setup port-forwarding to the VM %q: %v", ErrFailed, name, err)
|
|
}
|
|
defer wsConn.Close()
|
|
|
|
username, password = ChooseUsernameAndPassword(cmd.Context(), client, name, username, password)
|
|
|
|
sshConfig := &ssh.ClientConfig{
|
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
return nil
|
|
},
|
|
User: username,
|
|
Auth: []ssh.AuthMethod{
|
|
ssh.Password(password),
|
|
},
|
|
}
|
|
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(wsConn, "", sshConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: failed to establish an SSH connection: %v", ErrFailed, err)
|
|
}
|
|
|
|
sshClient := ssh.NewClient(sshConn, chans, reqs)
|
|
|
|
sshSess, err := sshClient.NewSession()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: failed to open an SSH session: %v", ErrFailed, err)
|
|
}
|
|
defer func() {
|
|
_ = sshSess.Close()
|
|
}()
|
|
|
|
if command != "" {
|
|
sshSess.Stdout = os.Stdout
|
|
sshSess.Stderr = os.Stderr
|
|
sshSess.Stdin = os.Stdin
|
|
|
|
return sshSess.Run(command)
|
|
}
|
|
|
|
// Switch controlling terminal into raw mode,
|
|
// otherwise ANSI escape sequences that allow
|
|
// for cursor control and more wouldn't work
|
|
stdinFD := int(os.Stdin.Fd())
|
|
state, err := term.MakeRaw(stdinFD)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: failed to switch controlling terminal into raw mode: %v", ErrFailed, err)
|
|
}
|
|
defer func() {
|
|
_ = term.Restore(stdinFD, state)
|
|
}()
|
|
|
|
stdoutFD := int(os.Stdout.Fd())
|
|
width, height, err := term.GetSize(int(os.Stdout.Fd()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := sshSess.RequestPty("xterm-256color", height, width, ssh.TerminalModes{}); err != nil {
|
|
return fmt.Errorf("%w: failed to request the PTY from the SSH server: %v", ErrFailed, err)
|
|
}
|
|
|
|
sshSess.Stdout = os.Stdout
|
|
sshSess.Stderr = os.Stderr
|
|
sshSessStdinPipe, err := sshSess.StdinPipe()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
_, _ = io.Copy(sshSessStdinPipe, os.Stdin)
|
|
_ = sshSessStdinPipe.Close()
|
|
}()
|
|
|
|
// Periodically adjust remote terminal size
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-time.After(time.Second * time.Duration(wait)):
|
|
// Proceed with adjusting the remote terminal size
|
|
case <-cmd.Context().Done():
|
|
return
|
|
}
|
|
|
|
newWidth, newHeight, err := term.GetSize(stdoutFD)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if height == newHeight && width == newWidth {
|
|
continue
|
|
}
|
|
|
|
if err := sshSess.WindowChange(newHeight, newWidth); err != nil {
|
|
continue
|
|
}
|
|
|
|
height = newHeight
|
|
width = newWidth
|
|
}
|
|
}()
|
|
|
|
if err := sshSess.Shell(); err != nil {
|
|
return err
|
|
}
|
|
|
|
sshErr := make(chan error)
|
|
|
|
go func() {
|
|
sshErr <- sshSess.Wait()
|
|
}()
|
|
|
|
select {
|
|
case err := <-sshErr:
|
|
return err
|
|
case <-cmd.Context().Done():
|
|
return cmd.Context().Err()
|
|
}
|
|
}
|
|
|
|
func ChooseUsernameAndPassword(
|
|
ctx context.Context,
|
|
client *client.Client,
|
|
vmName string,
|
|
usernameFromUser string,
|
|
passwordFromUser string,
|
|
) (string, string) {
|
|
// User settings override everything
|
|
if usernameFromUser != "" || passwordFromUser != "" {
|
|
return usernameFromUser, passwordFromUser
|
|
}
|
|
|
|
// Try to get the credentials from the VM's object stored on controller
|
|
vm, err := client.VMs().Get(ctx, vmName)
|
|
if err == nil && vm.Username != "" && vm.Password != "" {
|
|
return vm.Username, vm.Password
|
|
} else if err != nil {
|
|
fmt.Fprintf(os.Stderr, "failed to retrieve VM %s's credentials from the API server: %v\n",
|
|
vmName, err)
|
|
}
|
|
|
|
// Fall back
|
|
_, _ = fmt.Fprintf(os.Stderr, "no credentials specified or found, "+
|
|
"trying default admin:admin credentials...\n")
|
|
|
|
return "admin", "admin"
|
|
}
|