diff --git a/cmd/orchard/main.go b/cmd/orchard/main.go index d2891c1..66a15d8 100644 --- a/cmd/orchard/main.go +++ b/cmd/orchard/main.go @@ -2,8 +2,8 @@ package main import ( "context" + "fmt" "github.com/cirruslabs/orchard/internal/command" - "log" "os" "os/signal" ) @@ -25,6 +25,8 @@ func main() { // Run the command if err := command.NewRootCmd().ExecuteContext(ctx); err != nil { - log.Fatal(err) + _, _ = fmt.Fprintln(os.Stderr, err) + + os.Exit(1) } } diff --git a/go.mod b/go.mod index 550e4ec..fae13cd 100644 --- a/go.mod +++ b/go.mod @@ -26,10 +26,10 @@ require ( github.com/spf13/cobra v1.6.0 github.com/stretchr/testify v1.8.1 go.uber.org/zap v1.24.0 - golang.org/x/crypto v0.21.0 + golang.org/x/crypto v0.23.0 golang.org/x/exp v0.0.0-20230321023759-10a507213a29 golang.org/x/net v0.23.0 - golang.org/x/term v0.18.0 + golang.org/x/term v0.20.0 google.golang.org/grpc v1.56.3 google.golang.org/protobuf v1.30.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -102,8 +102,8 @@ require ( go.uber.org/atomic v1.10.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.15.0 // indirect google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index af83e24..a7f8b54 100644 --- a/go.sum +++ b/go.sum @@ -497,8 +497,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -629,12 +629,12 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -643,8 +643,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/command/controller/init.go b/internal/command/controller/init.go index 60eb89a..2a22117 100644 --- a/internal/command/controller/init.go +++ b/internal/command/controller/init.go @@ -12,7 +12,10 @@ import ( "fmt" "github.com/cirruslabs/orchard/internal/controller" "github.com/cirruslabs/orchard/internal/netconstants" + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" "math/big" + "os" "time" ) @@ -20,25 +23,80 @@ var ErrInitFailed = errors.New("controller initialization failed") var controllerCertPath string var controllerKeyPath string +var sshHostKeyPath string -func FindControllerCertificate(dataDir *controller.DataDir) (controllerCert tls.Certificate, err error) { +func FindControllerCertificate(dataDir *controller.DataDir) (tls.Certificate, error) { + // Prefer user-specified certificate and key if controllerCertPath != "" || controllerKeyPath != "" { - // if external certificate is specified, use it - if err = checkBothCertAndKeyAreSpecified(); err != nil { - return controllerCert, err + if err := checkBothCertAndKeyAreSpecified(); err != nil { + return tls.Certificate{}, err } + return tls.LoadX509KeyPair(controllerCertPath, controllerKeyPath) - } else if !dataDir.ControllerCertificateExists() { - // otherwise, generate a self-signed certificate if it's not already present - controllerCert, err = GenerateSelfSignedControllerCertificate() - if err != nil { - return controllerCert, err - } - if err = dataDir.SetControllerCertificate(controllerCert); err != nil { - return controllerCert, err - } } - return + + // Fall back to loading the certificate from the Orchard data directory + controllerCert, err := dataDir.ControllerCertificate() + if err == nil { + return controllerCert, nil + } + if err != nil && !errors.Is(err, os.ErrNotExist) { + return tls.Certificate{}, err + } + + // Fall back to generating a new certificate + controllerCert, err = GenerateSelfSignedControllerCertificate() + if err != nil { + return controllerCert, err + } + if err = dataDir.SetControllerCertificate(controllerCert); err != nil { + return controllerCert, err + } + + return controllerCert, nil +} + +func FindSSHHostKey(dataDir *controller.DataDir) (ssh.Signer, error) { + // Prefer user-specified host key + if sshHostKeyPath != "" { + hostKeyBytes, err := os.ReadFile(sshHostKeyPath) + if err != nil { + return nil, err + } + + signer, err := ssh.ParsePrivateKey(hostKeyBytes) + if err != nil { + return nil, err + } + + return signer, nil + } + + // Fall back to loading the host key from the Orchard data directory + signer, err := dataDir.SSHHostKey() + if err == nil { + return signer, err + } + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, err + } + + // Fall back to generating a new host key + _, privateKey, err := ed25519.GenerateKey(nil) + if err != nil { + return nil, err + } + + if err := dataDir.SetSSHHostKey(privateKey); err != nil { + return nil, err + } + + signer, err = ssh.NewSignerFromKey(privateKey) + if err != nil { + return nil, err + } + + return signer, nil } func checkBothCertAndKeyAreSpecified() error { diff --git a/internal/command/controller/run.go b/internal/command/controller/run.go index 8e111bb..afbd42b 100644 --- a/internal/command/controller/run.go +++ b/internal/command/controller/run.go @@ -17,6 +17,7 @@ var ErrRunFailed = errors.New("failed to run controller") var BootstrapAdminAccountName = "bootstrap-admin" var address string +var addressSSH string var debug bool func newRunCommand() *cobra.Command { @@ -31,7 +32,10 @@ func newRunCommand() *cobra.Command { port = strconv.FormatInt(netconstants.DefaultControllerPort, 10) } - cmd.PersistentFlags().StringVarP(&address, "listen", "l", fmt.Sprintf(":%s", port), "address to listen on") + cmd.PersistentFlags().StringVarP(&address, "listen", "l", fmt.Sprintf(":%s", port), + "address to listen on") + cmd.PersistentFlags().StringVar(&addressSSH, "listen-ssh", "", + "address for the built-in SSH server to listen on (e.g. \":6122\")") cmd.PersistentFlags().BoolVar(&debug, "debug", false, "enable debug logging") // flags for auto-init if necessary @@ -42,6 +46,8 @@ func newRunCommand() *cobra.Command { cmd.PersistentFlags().StringVar(&controllerKeyPath, "controller-key", "", "use the controller certificate key from the specified path instead of the auto-generated one"+ " (requires --controller-cert)") + cmd.PersistentFlags().StringVar(&sshHostKeyPath, "ssh-host-key", "", + "use the SSH private host key from the specified path instead of the auto-generated one") return cmd } @@ -74,20 +80,12 @@ func runController(cmd *cobra.Command, args []string) (err error) { return err } - var controllerCert tls.Certificate - if dataDir.ControllerCertificateExists() { - controllerCert, err = dataDir.ControllerCertificate() - if err != nil { - return err - } - } else { - controllerCert, err = FindControllerCertificate(dataDir) - if err != nil { - return err - } + controllerCert, err := FindControllerCertificate(dataDir) + if err != nil { + return err } - controllerInstance, err := controller.New( + controllerOpts := []controller.Option{ controller.WithListenAddr(address), controller.WithDataDir(dataDir), controller.WithLogger(logger), @@ -97,7 +95,18 @@ func runController(cmd *cobra.Command, args []string) (err error) { controllerCert, }, }), - ) + } + + if addressSSH != "" { + signer, err := FindSSHHostKey(dataDir) + if err != nil { + return err + } + + controllerOpts = append(controllerOpts, controller.WithSSHServer(addressSSH, signer)) + } + + controllerInstance, err := controller.New(controllerOpts...) if err != nil { return err } diff --git a/internal/command/create/vm.go b/internal/command/create/vm.go index c3aec4e..57affb7 100644 --- a/internal/command/create/vm.go +++ b/internal/command/create/vm.go @@ -3,6 +3,7 @@ package create import ( "errors" "fmt" + "github.com/cirruslabs/orchard/internal/simplename" "github.com/cirruslabs/orchard/pkg/client" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/spf13/cobra" @@ -66,6 +67,11 @@ func newCreateVMCommand() *cobra.Command { func runCreateVM(cmd *cobra.Command, args []string) error { name := args[0] + // Issue a warning if the name used will be invalid in the future + if err := simplename.ValidateNext(name); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "WARNING: %v\n", err) + } + // Convert arguments var hostDirs []v1.HostDir diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 8b5e3d6..ed7a7ef 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -8,12 +8,14 @@ import ( "github.com/cirruslabs/orchard/internal/controller/notifier" "github.com/cirruslabs/orchard/internal/controller/proxy" "github.com/cirruslabs/orchard/internal/controller/scheduler" + "github.com/cirruslabs/orchard/internal/controller/sshserver" storepkg "github.com/cirruslabs/orchard/internal/controller/store" "github.com/cirruslabs/orchard/internal/controller/store/badger" "github.com/cirruslabs/orchard/internal/netconstants" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/cirruslabs/orchard/rpc" "go.uber.org/zap" + "golang.org/x/crypto/ssh" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" "google.golang.org/grpc" @@ -53,6 +55,10 @@ type Controller struct { workerOfflineTimeout time.Duration maxWorkersPerLicense uint + sshListenAddr string + sshSigner ssh.Signer + sshServer *sshserver.SSHServer + rpc.UnimplementedControllerServer } @@ -94,16 +100,30 @@ func New(opts ...Option) (*Controller, error) { controller.logger = zap.NewNop().Sugar() } - // Instantiate controller + // Instantiate the database store, err := badger.NewBadgerStore(controller.dataDir.DBPath()) if err != nil { return nil, err } controller.store = store + + // Instantiate the worker notifier controller.workerNotifier = notifier.NewNotifier(controller.logger.With("component", "rpc")) + + // Instantiate the scheduler controller.scheduler = scheduler.NewScheduler(store, controller.workerNotifier, controller.workerOfflineTimeout, controller.logger) + // Instantiate the SSH server (if configured) + if controller.sshListenAddr != "" && controller.sshSigner != nil { + controller.sshServer, err = sshserver.NewSSHServer(controller.sshListenAddr, controller.sshSigner, + store, controller.proxy, controller.workerNotifier, controller.logger) + if err != nil { + return nil, err + } + } + + // Instantiate the controller listener, err := net.Listen("tcp", controller.listenAddr) if err != nil { return nil, err @@ -181,6 +201,11 @@ func (controller *Controller) Run(ctx context.Context) error { // be assigned to a specific Worker go controller.scheduler.Run() + // Run the SSH server (if configured) + if controller.sshServer != nil { + go controller.sshServer.Run() + } + // A helper function to shut down the HTTP server on context cancellation go func() { <-ctx.Done() @@ -206,3 +231,11 @@ func (controller *Controller) Address() string { return fmt.Sprintf("http://%s", hostPort) } + +func (controller *Controller) SSHAddress() (string, bool) { + if controller.sshServer == nil { + return "", false + } + + return controller.sshServer.Address(), true +} diff --git a/internal/controller/datadir.go b/internal/controller/datadir.go index d94377f..d16460e 100644 --- a/internal/controller/datadir.go +++ b/internal/controller/datadir.go @@ -1,25 +1,25 @@ package controller import ( + "crypto" "crypto/tls" "crypto/x509" "encoding/pem" "errors" "fmt" + "golang.org/x/crypto/ssh" "os" "path/filepath" ) -var ErrDataDirError = errors.New("controller's data directory operation error") - type DataDir struct { path string } func NewDataDir(path string) (*DataDir, error) { if err := os.MkdirAll(path, 0700); err != nil { - return nil, fmt.Errorf("%w: failed to create data directory at path %s: %v", - ErrDataDirError, path, err) + return nil, fmt.Errorf("failed to create data directory at path %s: %w", + path, err) } return &DataDir{ @@ -30,8 +30,7 @@ func NewDataDir(path string) (*DataDir, error) { func (dataDir *DataDir) ControllerCertificate() (tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(dataDir.ControllerCertificatePath(), dataDir.ControllerKeyPath()) if err != nil { - return tls.Certificate{}, fmt.Errorf("%w: failed to load controller's certificate and key: %v", - ErrDataDirError, err) + return tls.Certificate{}, fmt.Errorf("failed to load controller's certificate and key: %w", err) } return cert, nil @@ -45,8 +44,8 @@ func (dataDir *DataDir) SetControllerCertificate(certificate tls.Certificate) er privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(certificate.PrivateKey) if err != nil { - return fmt.Errorf("%w: failed to set controller's certificate: PKCS #8 marshalling failed: %v", - ErrDataDirError, err) + return fmt.Errorf("failed to set controller's certificate: PKCS #8 marshalling failed: %w", + err) } privateKeyPEMBytes := pem.EncodeToMemory(&pem.Block{ @@ -56,29 +55,38 @@ func (dataDir *DataDir) SetControllerCertificate(certificate tls.Certificate) er err = os.WriteFile(dataDir.ControllerCertificatePath(), certPEMBytes, 0600) if err != nil { - return fmt.Errorf("%w: failed to write controller's certificate: %v", ErrDataDirError, err) + return fmt.Errorf("failed to write controller's certificate: %w", err) } err = os.WriteFile(dataDir.ControllerKeyPath(), privateKeyPEMBytes, 0600) if err != nil { - return fmt.Errorf("%w: failed to write controller's key: %v", ErrDataDirError, err) + return fmt.Errorf("failed to write controller's key: %w", err) } return nil } +func (dataDir *DataDir) SSHHostKey() (ssh.Signer, error) { + hostKeyBytes, err := os.ReadFile(dataDir.SSHHostKeyPath()) + if err != nil { + return nil, err + } + + return ssh.ParsePrivateKey(hostKeyBytes) +} + +func (dataDir *DataDir) SetSSHHostKey(privateKey crypto.PrivateKey) error { + pemBlock, err := ssh.MarshalPrivateKey(privateKey, "") + if err != nil { + return err + } + + return os.WriteFile(dataDir.SSHHostKeyPath(), pem.EncodeToMemory(pemBlock), 0600) +} + func (dataDir *DataDir) DBPath() string { return filepath.Join(dataDir.path, "db") } -func (dataDir *DataDir) ControllerCertificateExists() bool { - return fileExist(dataDir.ControllerCertificatePath()) && fileExist(dataDir.ControllerKeyPath()) -} - -func fileExist(path string) bool { - _, err := os.Stat(path) - return err == nil -} - func (dataDir *DataDir) ControllerCertificatePath() string { return filepath.Join(dataDir.path, "controller.crt") } @@ -87,6 +95,10 @@ func (dataDir *DataDir) ControllerKeyPath() string { return filepath.Join(dataDir.path, "controller.key") } +func (dataDir *DataDir) SSHHostKeyPath() string { + return filepath.Join(dataDir.path, "ssh_host_ed25519_key") +} + func (dataDir *DataDir) Initialized() (bool, error) { dataDirEntries, err := os.ReadDir(dataDir.path) if err != nil { @@ -94,8 +106,8 @@ func (dataDir *DataDir) Initialized() (bool, error) { return false, nil } - return false, fmt.Errorf("%w: failed to read data directory contents at path %s: %v", - ErrDataDirError, dataDir.path, err) + return false, fmt.Errorf("failed to read data directory contents at path %s: %w", + dataDir.path, err) } return len(dataDirEntries) != 0, nil diff --git a/internal/controller/option.go b/internal/controller/option.go index 61740a8..7bc91ca 100644 --- a/internal/controller/option.go +++ b/internal/controller/option.go @@ -3,6 +3,7 @@ package controller import ( "crypto/tls" "go.uber.org/zap" + "golang.org/x/crypto/ssh" "time" ) @@ -26,6 +27,13 @@ func WithTLSConfig(tlsConfig *tls.Config) Option { } } +func WithSSHServer(listenAddr string, signer ssh.Signer) Option { + return func(controller *Controller) { + controller.sshListenAddr = listenAddr + controller.sshSigner = signer + } +} + func WithInsecureAuthDisabled() Option { return func(controller *Controller) { controller.insecureAuthDisabled = true diff --git a/internal/controller/sshserver/sshserver.go b/internal/controller/sshserver/sshserver.go new file mode 100644 index 0000000..6dbfa94 --- /dev/null +++ b/internal/controller/sshserver/sshserver.go @@ -0,0 +1,284 @@ +package sshserver + +import ( + "context" + "crypto/subtle" + "errors" + "fmt" + "github.com/cirruslabs/orchard/internal/controller/notifier" + proxypkg "github.com/cirruslabs/orchard/internal/controller/proxy" + storepkg "github.com/cirruslabs/orchard/internal/controller/store" + "github.com/cirruslabs/orchard/internal/proxy" + "github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/cirruslabs/orchard/rpc" + "github.com/google/uuid" + "github.com/samber/lo" + "go.uber.org/zap" + "golang.org/x/crypto/ssh" + "net" + "strings" +) + +const ( + // "ssh -J" uses channels of type "direct-tcpip", which are documented + // in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1]. + // + // [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2 + channelTypeDirectTCPIP = "direct-tcpip" +) + +type SSHServer struct { + listener net.Listener + serverConfig *ssh.ServerConfig + store storepkg.Store + proxy *proxypkg.Proxy + workerNotifier *notifier.Notifier + logger *zap.SugaredLogger +} + +func NewSSHServer( + address string, + signer ssh.Signer, + store storepkg.Store, + proxy *proxypkg.Proxy, + workerNotifier *notifier.Notifier, + logger *zap.SugaredLogger, +) (*SSHServer, error) { + server := &SSHServer{ + store: store, + proxy: proxy, + workerNotifier: workerNotifier, + logger: logger, + } + + listener, err := net.Listen("tcp", address) + if err != nil { + return nil, err + } + server.listener = listener + + server.serverConfig = &ssh.ServerConfig{ + PasswordCallback: server.passwordCallback, + } + server.serverConfig.AddHostKey(signer) + + return server, nil +} + +func (server *SSHServer) Run() { + for { + conn, err := server.listener.Accept() + if err != nil { + server.logger.Warnf("failed to accept connection: %v", err) + + continue + } + + go server.handleConnection(conn) + } +} + +func (server *SSHServer) Address() string { + return strings.ReplaceAll(server.listener.Addr().String(), "[::]", "127.0.0.1") +} + +func (server *SSHServer) passwordCallback(connMetadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if err := server.store.View(func(txn storepkg.Transaction) error { + // Authenticate + server.logger.Debugf("authenticating user %q using the password authentication", + connMetadata.User()) + + serviceAccount, err := txn.GetServiceAccount(connMetadata.User()) + if err != nil { + if errors.Is(err, storepkg.ErrNotFound) { + return fmt.Errorf("authentication failed, non-existent user %q", + connMetadata.User()) + } + + server.logger.Errorf("failed to retrieve service account %q: %v", + connMetadata.User(), err) + + return fmt.Errorf("authentication failed due to an internal error") + } + + if subtle.ConstantTimeCompare([]byte(serviceAccount.Token), password) != 1 { + return fmt.Errorf("authentication failed for user %q: invalid password", + connMetadata.User()) + } + + // Authorize + if !lo.Contains(serviceAccount.Roles, v1.ServiceAccountRoleComputeWrite) { + return fmt.Errorf("authorization failed for user %q because it lacks %q role", + connMetadata.User(), v1.ServiceAccountRoleComputeWrite) + } + + return nil + }); err != nil { + return nil, err + } + + return &ssh.Permissions{}, nil +} + +func (server *SSHServer) handleConnection(conn net.Conn) { + sshConn, newChannelCh, requestCh, err := ssh.NewServerConn(conn, server.serverConfig) + if err != nil { + server.logger.Warnf("failed to instantiate the SSH server instance to handle "+ + "the incoming connection from %s: %v", conn.RemoteAddr().String(), err) + + return + } + defer func() { + _ = sshConn.Close() + }() + + server.logger.Debugf("accepted SSH connection for user %q connecting from %q", + sshConn.User(), sshConn.RemoteAddr().String()) + + connCtx, connCtxCancel := context.WithCancel(context.Background()) + defer connCtxCancel() + + for { + select { + case newChannel, ok := <-newChannelCh: + if !ok { + return + } + + switch newChannel.ChannelType() { + case channelTypeDirectTCPIP: + server.logger.Debugf("handling a new direct TCP/IP channel for user %q connecting from %q", + sshConn.User(), sshConn.RemoteAddr().String()) + + go server.handleDirectTCPIP(connCtx, newChannel) + default: + message := fmt.Sprintf("unsupported channel type requested: %q", newChannel.ChannelType()) + + server.logger.Debugf(message) + + if err := newChannel.Reject(ssh.UnknownChannelType, message); err != nil { + server.logger.Warnf("failed to reject new channel of unsupported type %q: %v", + newChannel.ChannelType(), err) + + return + } + } + case request, ok := <-requestCh: + if !ok { + return + } + + server.logger.Debugf("refusing to service new request of type %q with payload of %d bytes", + request.Type, len(request.Payload)) + + if err := request.Reply(false, nil); err != nil { + server.logger.Warnf("failed to reply to a new request of type %q and payload of %d bytes: %v", + request.Type, len(request.Payload), err) + + return + } + } + } +} + +func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.NewChannel) { + // Unmarshal the payload to determine to which VM the user wants to connect to + // + // This direct TCP/IP channel's payload is documented + // in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1]. + // + // [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2 + payload := struct { + HostToConnect string + PortToConnect uint32 + OriginatorIPAddress string + OriginatorPort uint32 + }{} + + if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { + message := fmt.Sprintf("failed to unmarshal payload: %v", err) + + server.logger.Warn(message) + + if err := newChannel.Reject(ssh.ConnectionFailed, message); err != nil { + server.logger.Warnf("failed to reject the new channel: %v", err) + } + + return + } + + server.logger.Debugf("proxying connection to %s:%d", payload.HostToConnect, payload.PortToConnect) + + // Retrieve the VM object + var vm *v1.VM + var err error + + err = server.store.View(func(txn storepkg.Transaction) error { + vm, err = txn.GetVM(payload.HostToConnect) + + return err + }) + if err != nil { + if err := newChannel.Reject(ssh.ConnectionFailed, "failed to find VM"); err != nil { + server.logger.Warnf("failed to reject the new channel due to non-existent VM %q: %v", + payload.HostToConnect, err) + } + + return + } + + // The user wants to connect to an existing VM, request and wait + // for a connection with the worker before accepting the channel + session := uuid.New().String() + boomerangConnCh, cancel := server.proxy.Request(ctx, session) + defer cancel() + + err = server.workerNotifier.Notify(ctx, vm.Worker, &rpc.WatchInstruction{ + Action: &rpc.WatchInstruction_PortForwardAction{ + PortForwardAction: &rpc.WatchInstruction_PortForward{ + Session: session, + VmUid: vm.UID, + Port: payload.PortToConnect, + }, + }, + }) + if err != nil { + server.logger.Warnf("failed to request port-forwarding from the worker %s: %v", + vm.Worker, err) + + return + } + + // Wait for the connection from worker and commence port forwarding + select { + case fromWorkerConnection := <-boomerangConnCh: + // Now that we have the connection from worker we can accept the channel + acceptedChannel, acceptedChannelRequests, err := newChannel.Accept() + if err != nil { + server.logger.Warnf("failed to accept the new channel: %v", err) + + return + } + + // Handle new requests on the accepted channel by refusing them + go func() { + req, ok := <-acceptedChannelRequests + if !ok { + return + } + + if err := req.Reply(false, nil); err != nil { + server.logger.Warnf("failed to reply to the new channel request: %v", err) + + return + } + }() + + // Commence port forwarding + if err := proxy.Connections(acceptedChannel, fromWorkerConnection); err != nil { + server.logger.Warnf("failed to port-forward: %v", err) + } + case <-ctx.Done(): + return + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 040c9a5..52551b3 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,11 +2,10 @@ package proxy import ( "io" - "net" "strings" ) -func Connections(left net.Conn, right net.Conn) (finalErr error) { +func Connections(left io.ReadWriteCloser, right io.ReadWriteCloser) (finalErr error) { leftErrCh := make(chan error, 1) rightErrCh := make(chan error, 1) diff --git a/internal/simplename/simplename.go b/internal/simplename/simplename.go index 5c42136..28068ec 100644 --- a/internal/simplename/simplename.go +++ b/internal/simplename/simplename.go @@ -2,9 +2,14 @@ package simplename import ( "errors" + "fmt" ) -var ErrNotASimpleName = errors.New("name contains restricted characters, please only use [A-Za-z0-9:-_.]") +var ( + ErrNotASimpleName = errors.New("name contains restricted characters, please only use [A-Za-z0-9:-_.]") + + ErrNotASimpleNameNext = errors.New("names with characters other than [a-z0-9-] will be deprecated in the future") +) func Validate(s string) error { for _, ch := range s { @@ -29,3 +34,53 @@ func Validate(s string) error { return nil } + +func ValidateNext(s string) error { + // Ensure that the name is not empty + if s == "" { + return fmt.Errorf("name cannot be empty") + } + + // Ensure that the name is 63 characters or fewer + if len(s) > 63 { + return fmt.Errorf("names with more than 63 characters " + + "will be depreacted in the future") + } + + // Ensure that the name starts and ends with an alphanumeric character + if !isAlphanumeric(s[0]) { + return fmt.Errorf("names not starting with an alphanumeric character " + + "will be deprecated in the future") + } + + if !isAlphanumeric(s[len(s)-1]) { + return fmt.Errorf("names not ending with an alphanumeric character " + + "will be deprecated in the future") + } + + for i := range s { + if isAlphanumeric(s[i]) { + continue + } + + if s[i] == '-' { + continue + } + + return ErrNotASimpleNameNext + } + + return nil +} + +func isAlphanumeric(ch uint8) bool { + if ch >= 'a' && ch <= 'z' { + return true + } + + if ch >= '0' && ch <= '9' { + return true + } + + return false +} diff --git a/internal/simplename/simplename_test.go b/internal/simplename/simplename_test.go index 8c36821..7d149b3 100644 --- a/internal/simplename/simplename_test.go +++ b/internal/simplename/simplename_test.go @@ -15,3 +15,23 @@ func TestValidate(t *testing.T) { require.Error(t, simplename.Validate("vm%"), "special characters") require.Error(t, simplename.Validate("😐"), "non-ASCII characters") } + +func TestValidateNext(t *testing.T) { + require.NoError(t, simplename.ValidateNext("abcdefghijklmnopqrstuvwxyz-01234567890")) + require.NoError(t, simplename.ValidateNext("vm-1")) + require.NoError(t, simplename.ValidateNext("host-local")) + require.NoError(t, simplename.ValidateNext("x")) + + require.Error(t, simplename.ValidateNext("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), + "uppercase characters") + require.Error(t, simplename.ValidateNext(".test"), "does not start with an alphanumeric character") + require.Error(t, simplename.ValidateNext("test."), "does not end with an alphanumeric character") + require.Error(t, simplename.ValidateNext("vm:1"), "special characters") + require.Error(t, simplename.ValidateNext("vm_1"), "special characters") + require.Error(t, simplename.ValidateNext("vm.1"), "special characters") + require.Error(t, simplename.ValidateNext("vm%"), "special characters") + require.Error(t, simplename.ValidateNext("😐"), "non-ASCII characters") + require.Error(t, simplename.ValidateNext(""), "empty name") + require.Error(t, simplename.ValidateNext("1234567890123456789012345678901234567890123456789012345678901234"), + "too long") +} diff --git a/internal/tests/devcontroller/devcontroller.go b/internal/tests/devcontroller/devcontroller.go new file mode 100644 index 0000000..37f534f --- /dev/null +++ b/internal/tests/devcontroller/devcontroller.go @@ -0,0 +1,57 @@ +package devcontroller + +import ( + "context" + "errors" + "github.com/cirruslabs/orchard/internal/command/dev" + "github.com/cirruslabs/orchard/internal/controller" + "github.com/cirruslabs/orchard/internal/worker" + "github.com/cirruslabs/orchard/pkg/client" + "github.com/stretchr/testify/require" + "net/http" + "testing" + "time" +) + +func StartIntegrationTestEnvironment(t *testing.T) (*client.Client, *controller.Controller, *worker.Worker) { + return StartIntegrationTestEnvironmentWithAdditionalOpts(t, nil, nil) +} + +func StartIntegrationTestEnvironmentWithAdditionalOpts( + t *testing.T, + additionalControllerOpts []controller.Option, + additionalWorkerOpts []worker.Option, +) (*client.Client, *controller.Controller, *worker.Worker) { + t.Setenv("ORCHARD_HOME", t.TempDir()) + + devController, devWorker, err := dev.CreateDevControllerAndWorker(t.TempDir(), + ":0", nil, additionalControllerOpts, additionalWorkerOpts) + require.NoError(t, err) + t.Cleanup(func() { + _ = devWorker.Close() + }) + + devContext, cancelDevFunc := context.WithCancel(context.Background()) + t.Cleanup(cancelDevFunc) + + go func() { + err := devController.Run(devContext) + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("dev controller failed: %v", err) + } + }() + + go func() { + err := devWorker.Run(devContext) + if err != nil && !errors.Is(err, context.Canceled) { + t.Errorf("dev worker failed: %v", err) + } + }() + + time.Sleep(5 * time.Second) + + devClient, err := client.New(client.WithAddress(devController.Address())) + require.NoError(t, err) + + return devClient, devController, devWorker +} diff --git a/internal/tests/integration_test.go b/internal/tests/integration_test.go index e66553b..c41a881 100644 --- a/internal/tests/integration_test.go +++ b/internal/tests/integration_test.go @@ -2,14 +2,12 @@ package tests_test import ( "context" - "errors" "fmt" - "github.com/cirruslabs/orchard/internal/command/dev" "github.com/cirruslabs/orchard/internal/controller" - "github.com/cirruslabs/orchard/internal/worker" + "github.com/cirruslabs/orchard/internal/tests/devcontroller" + "github.com/cirruslabs/orchard/internal/tests/wait" "github.com/cirruslabs/orchard/internal/worker/ondiskname" "github.com/cirruslabs/orchard/internal/worker/tart" - "github.com/cirruslabs/orchard/pkg/client" v1 "github.com/cirruslabs/orchard/pkg/resource/v1" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -18,7 +16,6 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "net" - "net/http" "os" "path/filepath" "strconv" @@ -28,7 +25,7 @@ import ( ) func TestSingleVM(t *testing.T) { - devClient := StartIntegrationTestEnvironment(t) + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) workers, err := devClient.Workers().List(context.Background()) if err != nil { @@ -52,7 +49,7 @@ func TestSingleVM(t *testing.T) { if err != nil { t.Fatal(err) } - assert.True(t, Wait(2*time.Minute, func() bool { + assert.True(t, wait.Wait(2*time.Minute, func() bool { vm, err := devClient.VMs().Get(context.Background(), "test-vm") if err != nil { t.Fatal(err) @@ -66,7 +63,7 @@ func TestSingleVM(t *testing.T) { } assert.Empty(t, runningVM.StatusMessage) assert.Equal(t, v1.VMStatusRunning, runningVM.Status) - assert.True(t, Wait(2*time.Minute, func() bool { + assert.True(t, wait.Wait(2*time.Minute, func() bool { logLines, err := devClient.VMs().Logs(context.Background(), "test-vm") if err != nil { t.Fatal(err) @@ -92,7 +89,7 @@ func TestSingleVM(t *testing.T) { require.NoError(t, devClient.VMs().Delete(context.Background(), "test-vm")) // Ensure that the worker has deleted this VM from disk - assert.True(t, Wait(2*time.Minute, func() bool { + assert.True(t, wait.Wait(2*time.Minute, func() bool { t.Logf("Waiting for the VM to be garbage collected...") return !hasVMByPredicate(t, func(info tart.VMInfo) bool { @@ -102,7 +99,7 @@ func TestSingleVM(t *testing.T) { } func TestFailedStartupScript(t *testing.T) { - devClient := StartIntegrationTestEnvironment(t) + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) workers, err := devClient.Workers().List(context.Background()) if err != nil { @@ -125,7 +122,7 @@ func TestFailedStartupScript(t *testing.T) { if err != nil { t.Fatal(err) } - assert.True(t, Wait(2*time.Minute, func() bool { + assert.True(t, wait.Wait(2*time.Minute, func() bool { vm, err := devClient.VMs().Get(context.Background(), "test-vm") if err != nil { t.Fatal(err) @@ -141,72 +138,10 @@ func TestFailedStartupScript(t *testing.T) { "failed to run startup script: Process exited with status 123") } -func Wait(duration time.Duration, condition func() bool) bool { - ctx, cancel := context.WithTimeout(context.Background(), duration) - defer cancel() - for { - if condition() { - // all good - return true - } - select { - case <-ctx.Done(): - return false - case <-time.After(5 * time.Second): - // try again - continue - } - } -} - -func StartIntegrationTestEnvironment( - t *testing.T, -) *client.Client { - return StartIntegrationTestEnvironmentWithAdditionalOpts(t, nil, nil) -} - -func StartIntegrationTestEnvironmentWithAdditionalOpts( - t *testing.T, - additionalControllerOpts []controller.Option, - additionalWorkerOpts []worker.Option, -) *client.Client { - t.Setenv("ORCHARD_HOME", t.TempDir()) - devController, devWorker, err := dev.CreateDevControllerAndWorker(t.TempDir(), - ":0", nil, additionalControllerOpts, additionalWorkerOpts) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - _ = devWorker.Close() - }) - devContext, cancelDevFunc := context.WithCancel(context.Background()) - t.Cleanup(cancelDevFunc) - go func() { - err := devController.Run(devContext) - if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { - t.Errorf("dev controller failed: %v", err) - } - }() - go func() { - err := devWorker.Run(devContext) - if err != nil && !errors.Is(err, context.Canceled) { - t.Errorf("dev worker failed: %v", err) - } - }() - - time.Sleep(5 * time.Second) - - devClient, err := client.New(client.WithAddress(devController.Address())) - if err != nil { - t.Fatal(err) - } - return devClient -} - func TestPortForwarding(t *testing.T) { ctx := context.Background() - devClient := StartIntegrationTestEnvironment(t) + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) // Create a generic macOS VM err := devClient.VMs().Create(ctx, &v1.VM{ @@ -261,7 +196,7 @@ func TestPortForwarding(t *testing.T) { func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) { ctx := context.Background() - devClient := StartIntegrationTestEnvironment(t) + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) const ( dummyWorkerName = "dummy-worker" @@ -304,7 +239,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) { require.NoError(t, err) // Wait for the dummy VM to get scheduled to a dummy worker - require.True(t, Wait(2*time.Minute, func() bool { + require.True(t, wait.Wait(2*time.Minute, func() bool { vm, err := devClient.VMs().Get(context.Background(), dummyVMName) require.NoError(t, err) @@ -318,7 +253,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) { require.NoError(t, err) // Wait for the scheduler to change the dummy VM's status to "failed" - require.True(t, Wait(2*time.Minute, func() bool { + require.True(t, wait.Wait(2*time.Minute, func() bool { vm, err := devClient.VMs().Get(context.Background(), dummyVMName) require.NoError(t, err) @@ -340,7 +275,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) { func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) { ctx := context.Background() - devClient := StartIntegrationTestEnvironmentWithAdditionalOpts(t, + devClient, _, _ := devcontroller.StartIntegrationTestEnvironmentWithAdditionalOpts(t, []controller.Option{controller.WithWorkerOfflineTimeout(1 * time.Minute)}, nil) const ( @@ -379,7 +314,7 @@ func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) { require.NoError(t, err) // Wait for the VM to be marked as failed - assert.True(t, Wait(2*time.Minute, func() bool { + assert.True(t, wait.Wait(2*time.Minute, func() bool { vm, err := devClient.VMs().Get(context.Background(), dummyVMName) require.NoError(t, err) @@ -415,10 +350,10 @@ func TestVMGarbageCollection(t *testing.T) { require.True(t, hasVM(t, vmName, logger)) // Start the Orchard Worker - _ = StartIntegrationTestEnvironment(t) + _, _, _ = devcontroller.StartIntegrationTestEnvironment(t) // Wait for the Orchard Worker to garbage-collect this VM - require.True(t, Wait(2*time.Minute, func() bool { + require.True(t, wait.Wait(2*time.Minute, func() bool { t.Logf("Waiting for the on-disk VM to be cleaned up by the worker") return !hasVM(t, vmName, logger) @@ -426,7 +361,7 @@ func TestVMGarbageCollection(t *testing.T) { } func TestHostDirs(t *testing.T) { - devClient := StartIntegrationTestEnvironment(t) + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) dirToMount := t.TempDir() @@ -461,7 +396,7 @@ func TestHostDirs(t *testing.T) { var vm *v1.VM - require.True(t, Wait(2*time.Minute, func() bool { + require.True(t, wait.Wait(2*time.Minute, func() bool { vm, err = devClient.VMs().Get(context.Background(), vmName) require.NoError(t, err) @@ -475,7 +410,7 @@ func TestHostDirs(t *testing.T) { var logLines []string - require.True(t, Wait(2*time.Minute, func() bool { + require.True(t, wait.Wait(2*time.Minute, func() bool { logLines, err = devClient.VMs().Logs(context.Background(), vmName) require.NoError(t, err) @@ -495,7 +430,7 @@ func TestHostDirs(t *testing.T) { } func TestHostDirsInvalidPolicy(t *testing.T) { - devClient := StartIntegrationTestEnvironment(t) + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) dirToMount := t.TempDir() diff --git a/internal/tests/sshserver_test.go b/internal/tests/sshserver_test.go new file mode 100644 index 0000000..4668663 --- /dev/null +++ b/internal/tests/sshserver_test.go @@ -0,0 +1,109 @@ +package tests_test + +import ( + "context" + "crypto/subtle" + "fmt" + "github.com/cirruslabs/orchard/internal/controller" + "github.com/cirruslabs/orchard/internal/tests/devcontroller" + "github.com/cirruslabs/orchard/internal/tests/wait" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" + "net" + "testing" + "time" +) + +func TestSSHServer(t *testing.T) { + // Generate SSH host key for the Controller + publicKey, privateKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + sshPublicKey, err := ssh.NewPublicKey(publicKey) + require.NoError(t, err) + + signer, err := ssh.NewSignerFromKey(privateKey) + require.NoError(t, err) + + // Run the Controller + devClient, devController, _ := devcontroller.StartIntegrationTestEnvironmentWithAdditionalOpts(t, []controller.Option{ + controller.WithSSHServer(":0", signer), + }, nil) + + // Create a VM to which we'll connect via Controller's SSH server + err = devClient.VMs().Create(context.Background(), &v1.VM{ + Meta: v1.Meta{ + Name: "test-vm", + }, + Image: "ghcr.io/cirruslabs/macos-sonoma-base:latest", + CPU: 4, + Memory: 8 * 1024, + Headless: true, + Status: v1.VMStatusPending, + }) + require.NoError(t, err) + + // Wait for the VM to start + assert.True(t, wait.Wait(2*time.Minute, func() bool { + vm, err := devClient.VMs().Get(context.Background(), "test-vm") + require.NoError(t, err) + + return vm.Status == v1.VMStatusRunning + }), "failed to wait for the VM to start") + + // Create a service account whose credentials we'll use to connect to the Controller's SSH server + require.NoError(t, devClient.ServiceAccounts().Create(context.Background(), &v1.ServiceAccount{ + Meta: v1.Meta{ + Name: "ssh-user", + }, + Token: "ssh-password", + Roles: []v1.ServiceAccountRole{ + v1.ServiceAccountRoleComputeWrite, + }, + })) + + // Connect to the VM over Orchard Controller's SSH server + sshAddress, ok := devController.SSHAddress() + require.True(t, ok) + + sshClientController, err := ssh.Dial("tcp", sshAddress, &ssh.ClientConfig{ + User: "ssh-user", + Auth: []ssh.AuthMethod{ + ssh.Password("ssh-password"), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if subtle.ConstantTimeCompare(sshPublicKey.Marshal(), key.Marshal()) != 1 { + return fmt.Errorf("untrustred public key was presented by the Controller") + } + + return nil + }, + }) + require.NoError(t, err) + + netConnVM, err := sshClientController.Dial("tcp", "test-vm:22") + require.NoError(t, err) + + sshConnVM, sshChansVM, sshReqsVM, err := ssh.NewClientConn(netConnVM, "test-vm:22", &ssh.ClientConfig{ + User: "admin", + Auth: []ssh.AuthMethod{ + ssh.Password("admin"), + }, + HostKeyCallback: func(_ string, _ net.Addr, _ ssh.PublicKey) error { + return nil + }, + }) + require.NoError(t, err) + + sshClientVM := ssh.NewClient(sshConnVM, sshChansVM, sshReqsVM) + + sshSessVM, err := sshClientVM.NewSession() + require.NoError(t, err) + + unameBytes, err := sshSessVM.Output("uname -a") + require.NoError(t, err) + require.Contains(t, string(unameBytes), "Darwin") +} diff --git a/internal/tests/wait/wait.go b/internal/tests/wait/wait.go new file mode 100644 index 0000000..a736170 --- /dev/null +++ b/internal/tests/wait/wait.go @@ -0,0 +1,24 @@ +package wait + +import ( + "context" + "time" +) + +func Wait(duration time.Duration, predicate func() bool) bool { + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + + for { + if predicate() { + return true + } + + select { + case <-ctx.Done(): + return false + case <-time.After(1 * time.Second): + continue + } + } +}