Orchard Controller: implement an SSH server that acts as a jump host (#179)
* proxy.Connections(): require io.ReadWriteCloser instead of net.Conn * Orchard Controller: implement an SSH server that acts as a jump host * Issue a warning if the name used will be invalid in the future * Further restrict uppercase characters in names in the future The rationale is similar to https://github.com/kubernetes/kubernetes/issues/71140. We won't want to munge the user's input and introduce subtle bugs doing lowercase comparisons.
This commit is contained in:
parent
c845f3b2fd
commit
d59bc7f8a7
|
|
@ -2,8 +2,8 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cirruslabs/orchard/internal/command"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
)
|
||||
|
|
@ -25,6 +25,8 @@ func main() {
|
|||
|
||||
// Run the command
|
||||
if err := command.NewRootCmd().ExecuteContext(ctx); err != nil {
|
||||
log.Fatal(err)
|
||||
_, _ = fmt.Fprintln(os.Stderr, err)
|
||||
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
8
go.mod
8
go.mod
|
|
@ -26,10 +26,10 @@ require (
|
|||
github.com/spf13/cobra v1.6.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
go.uber.org/zap v1.24.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29
|
||||
golang.org/x/net v0.23.0
|
||||
golang.org/x/term v0.18.0
|
||||
golang.org/x/term v0.20.0
|
||||
google.golang.org/grpc v1.56.3
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
|
|
@ -102,8 +102,8 @@ require (
|
|||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
|
|
|||
16
go.sum
16
go.sum
|
|
@ -497,8 +497,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
|||
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
|
|
@ -629,12 +629,12 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
|
||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
|
@ -643,8 +643,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
|
|
|||
|
|
@ -12,7 +12,10 @@ import (
|
|||
"fmt"
|
||||
"github.com/cirruslabs/orchard/internal/controller"
|
||||
"github.com/cirruslabs/orchard/internal/netconstants"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -20,25 +23,80 @@ var ErrInitFailed = errors.New("controller initialization failed")
|
|||
|
||||
var controllerCertPath string
|
||||
var controllerKeyPath string
|
||||
var sshHostKeyPath string
|
||||
|
||||
func FindControllerCertificate(dataDir *controller.DataDir) (controllerCert tls.Certificate, err error) {
|
||||
func FindControllerCertificate(dataDir *controller.DataDir) (tls.Certificate, error) {
|
||||
// Prefer user-specified certificate and key
|
||||
if controllerCertPath != "" || controllerKeyPath != "" {
|
||||
// if external certificate is specified, use it
|
||||
if err = checkBothCertAndKeyAreSpecified(); err != nil {
|
||||
return controllerCert, err
|
||||
if err := checkBothCertAndKeyAreSpecified(); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
return tls.LoadX509KeyPair(controllerCertPath, controllerKeyPath)
|
||||
} else if !dataDir.ControllerCertificateExists() {
|
||||
// otherwise, generate a self-signed certificate if it's not already present
|
||||
controllerCert, err = GenerateSelfSignedControllerCertificate()
|
||||
if err != nil {
|
||||
return controllerCert, err
|
||||
}
|
||||
if err = dataDir.SetControllerCertificate(controllerCert); err != nil {
|
||||
return controllerCert, err
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
// Fall back to loading the certificate from the Orchard data directory
|
||||
controllerCert, err := dataDir.ControllerCertificate()
|
||||
if err == nil {
|
||||
return controllerCert, nil
|
||||
}
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
// Fall back to generating a new certificate
|
||||
controllerCert, err = GenerateSelfSignedControllerCertificate()
|
||||
if err != nil {
|
||||
return controllerCert, err
|
||||
}
|
||||
if err = dataDir.SetControllerCertificate(controllerCert); err != nil {
|
||||
return controllerCert, err
|
||||
}
|
||||
|
||||
return controllerCert, nil
|
||||
}
|
||||
|
||||
func FindSSHHostKey(dataDir *controller.DataDir) (ssh.Signer, error) {
|
||||
// Prefer user-specified host key
|
||||
if sshHostKeyPath != "" {
|
||||
hostKeyBytes, err := os.ReadFile(sshHostKeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(hostKeyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
// Fall back to loading the host key from the Orchard data directory
|
||||
signer, err := dataDir.SSHHostKey()
|
||||
if err == nil {
|
||||
return signer, err
|
||||
}
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fall back to generating a new host key
|
||||
_, privateKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := dataDir.SetSSHHostKey(privateKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signer, err = ssh.NewSignerFromKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func checkBothCertAndKeyAreSpecified() error {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ var ErrRunFailed = errors.New("failed to run controller")
|
|||
var BootstrapAdminAccountName = "bootstrap-admin"
|
||||
|
||||
var address string
|
||||
var addressSSH string
|
||||
var debug bool
|
||||
|
||||
func newRunCommand() *cobra.Command {
|
||||
|
|
@ -31,7 +32,10 @@ func newRunCommand() *cobra.Command {
|
|||
port = strconv.FormatInt(netconstants.DefaultControllerPort, 10)
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().StringVarP(&address, "listen", "l", fmt.Sprintf(":%s", port), "address to listen on")
|
||||
cmd.PersistentFlags().StringVarP(&address, "listen", "l", fmt.Sprintf(":%s", port),
|
||||
"address to listen on")
|
||||
cmd.PersistentFlags().StringVar(&addressSSH, "listen-ssh", "",
|
||||
"address for the built-in SSH server to listen on (e.g. \":6122\")")
|
||||
cmd.PersistentFlags().BoolVar(&debug, "debug", false, "enable debug logging")
|
||||
|
||||
// flags for auto-init if necessary
|
||||
|
|
@ -42,6 +46,8 @@ func newRunCommand() *cobra.Command {
|
|||
cmd.PersistentFlags().StringVar(&controllerKeyPath, "controller-key", "",
|
||||
"use the controller certificate key from the specified path instead of the auto-generated one"+
|
||||
" (requires --controller-cert)")
|
||||
cmd.PersistentFlags().StringVar(&sshHostKeyPath, "ssh-host-key", "",
|
||||
"use the SSH private host key from the specified path instead of the auto-generated one")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
|
@ -74,20 +80,12 @@ func runController(cmd *cobra.Command, args []string) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
var controllerCert tls.Certificate
|
||||
if dataDir.ControllerCertificateExists() {
|
||||
controllerCert, err = dataDir.ControllerCertificate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
controllerCert, err = FindControllerCertificate(dataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
controllerCert, err := FindControllerCertificate(dataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
controllerInstance, err := controller.New(
|
||||
controllerOpts := []controller.Option{
|
||||
controller.WithListenAddr(address),
|
||||
controller.WithDataDir(dataDir),
|
||||
controller.WithLogger(logger),
|
||||
|
|
@ -97,7 +95,18 @@ func runController(cmd *cobra.Command, args []string) (err error) {
|
|||
controllerCert,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
if addressSSH != "" {
|
||||
signer, err := FindSSHHostKey(dataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
controllerOpts = append(controllerOpts, controller.WithSSHServer(addressSSH, signer))
|
||||
}
|
||||
|
||||
controllerInstance, err := controller.New(controllerOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package create
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cirruslabs/orchard/internal/simplename"
|
||||
"github.com/cirruslabs/orchard/pkg/client"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/spf13/cobra"
|
||||
|
|
@ -66,6 +67,11 @@ func newCreateVMCommand() *cobra.Command {
|
|||
func runCreateVM(cmd *cobra.Command, args []string) error {
|
||||
name := args[0]
|
||||
|
||||
// Issue a warning if the name used will be invalid in the future
|
||||
if err := simplename.ValidateNext(name); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "WARNING: %v\n", err)
|
||||
}
|
||||
|
||||
// Convert arguments
|
||||
var hostDirs []v1.HostDir
|
||||
|
||||
|
|
|
|||
|
|
@ -8,12 +8,14 @@ import (
|
|||
"github.com/cirruslabs/orchard/internal/controller/notifier"
|
||||
"github.com/cirruslabs/orchard/internal/controller/proxy"
|
||||
"github.com/cirruslabs/orchard/internal/controller/scheduler"
|
||||
"github.com/cirruslabs/orchard/internal/controller/sshserver"
|
||||
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
|
||||
"github.com/cirruslabs/orchard/internal/controller/store/badger"
|
||||
"github.com/cirruslabs/orchard/internal/netconstants"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/cirruslabs/orchard/rpc"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
"google.golang.org/grpc"
|
||||
|
|
@ -53,6 +55,10 @@ type Controller struct {
|
|||
workerOfflineTimeout time.Duration
|
||||
maxWorkersPerLicense uint
|
||||
|
||||
sshListenAddr string
|
||||
sshSigner ssh.Signer
|
||||
sshServer *sshserver.SSHServer
|
||||
|
||||
rpc.UnimplementedControllerServer
|
||||
}
|
||||
|
||||
|
|
@ -94,16 +100,30 @@ func New(opts ...Option) (*Controller, error) {
|
|||
controller.logger = zap.NewNop().Sugar()
|
||||
}
|
||||
|
||||
// Instantiate controller
|
||||
// Instantiate the database
|
||||
store, err := badger.NewBadgerStore(controller.dataDir.DBPath())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
controller.store = store
|
||||
|
||||
// Instantiate the worker notifier
|
||||
controller.workerNotifier = notifier.NewNotifier(controller.logger.With("component", "rpc"))
|
||||
|
||||
// Instantiate the scheduler
|
||||
controller.scheduler = scheduler.NewScheduler(store, controller.workerNotifier,
|
||||
controller.workerOfflineTimeout, controller.logger)
|
||||
|
||||
// Instantiate the SSH server (if configured)
|
||||
if controller.sshListenAddr != "" && controller.sshSigner != nil {
|
||||
controller.sshServer, err = sshserver.NewSSHServer(controller.sshListenAddr, controller.sshSigner,
|
||||
store, controller.proxy, controller.workerNotifier, controller.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate the controller
|
||||
listener, err := net.Listen("tcp", controller.listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -181,6 +201,11 @@ func (controller *Controller) Run(ctx context.Context) error {
|
|||
// be assigned to a specific Worker
|
||||
go controller.scheduler.Run()
|
||||
|
||||
// Run the SSH server (if configured)
|
||||
if controller.sshServer != nil {
|
||||
go controller.sshServer.Run()
|
||||
}
|
||||
|
||||
// A helper function to shut down the HTTP server on context cancellation
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
|
@ -206,3 +231,11 @@ func (controller *Controller) Address() string {
|
|||
|
||||
return fmt.Sprintf("http://%s", hostPort)
|
||||
}
|
||||
|
||||
func (controller *Controller) SSHAddress() (string, bool) {
|
||||
if controller.sshServer == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return controller.sshServer.Address(), true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,25 +1,25 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var ErrDataDirError = errors.New("controller's data directory operation error")
|
||||
|
||||
type DataDir struct {
|
||||
path string
|
||||
}
|
||||
|
||||
func NewDataDir(path string) (*DataDir, error) {
|
||||
if err := os.MkdirAll(path, 0700); err != nil {
|
||||
return nil, fmt.Errorf("%w: failed to create data directory at path %s: %v",
|
||||
ErrDataDirError, path, err)
|
||||
return nil, fmt.Errorf("failed to create data directory at path %s: %w",
|
||||
path, err)
|
||||
}
|
||||
|
||||
return &DataDir{
|
||||
|
|
@ -30,8 +30,7 @@ func NewDataDir(path string) (*DataDir, error) {
|
|||
func (dataDir *DataDir) ControllerCertificate() (tls.Certificate, error) {
|
||||
cert, err := tls.LoadX509KeyPair(dataDir.ControllerCertificatePath(), dataDir.ControllerKeyPath())
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("%w: failed to load controller's certificate and key: %v",
|
||||
ErrDataDirError, err)
|
||||
return tls.Certificate{}, fmt.Errorf("failed to load controller's certificate and key: %w", err)
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
|
|
@ -45,8 +44,8 @@ func (dataDir *DataDir) SetControllerCertificate(certificate tls.Certificate) er
|
|||
|
||||
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(certificate.PrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: failed to set controller's certificate: PKCS #8 marshalling failed: %v",
|
||||
ErrDataDirError, err)
|
||||
return fmt.Errorf("failed to set controller's certificate: PKCS #8 marshalling failed: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
privateKeyPEMBytes := pem.EncodeToMemory(&pem.Block{
|
||||
|
|
@ -56,29 +55,38 @@ func (dataDir *DataDir) SetControllerCertificate(certificate tls.Certificate) er
|
|||
|
||||
err = os.WriteFile(dataDir.ControllerCertificatePath(), certPEMBytes, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: failed to write controller's certificate: %v", ErrDataDirError, err)
|
||||
return fmt.Errorf("failed to write controller's certificate: %w", err)
|
||||
}
|
||||
err = os.WriteFile(dataDir.ControllerKeyPath(), privateKeyPEMBytes, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: failed to write controller's key: %v", ErrDataDirError, err)
|
||||
return fmt.Errorf("failed to write controller's key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) SSHHostKey() (ssh.Signer, error) {
|
||||
hostKeyBytes, err := os.ReadFile(dataDir.SSHHostKeyPath())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ssh.ParsePrivateKey(hostKeyBytes)
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) SetSSHHostKey(privateKey crypto.PrivateKey) error {
|
||||
pemBlock, err := ssh.MarshalPrivateKey(privateKey, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(dataDir.SSHHostKeyPath(), pem.EncodeToMemory(pemBlock), 0600)
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) DBPath() string {
|
||||
return filepath.Join(dataDir.path, "db")
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) ControllerCertificateExists() bool {
|
||||
return fileExist(dataDir.ControllerCertificatePath()) && fileExist(dataDir.ControllerKeyPath())
|
||||
}
|
||||
|
||||
func fileExist(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) ControllerCertificatePath() string {
|
||||
return filepath.Join(dataDir.path, "controller.crt")
|
||||
}
|
||||
|
|
@ -87,6 +95,10 @@ func (dataDir *DataDir) ControllerKeyPath() string {
|
|||
return filepath.Join(dataDir.path, "controller.key")
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) SSHHostKeyPath() string {
|
||||
return filepath.Join(dataDir.path, "ssh_host_ed25519_key")
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) Initialized() (bool, error) {
|
||||
dataDirEntries, err := os.ReadDir(dataDir.path)
|
||||
if err != nil {
|
||||
|
|
@ -94,8 +106,8 @@ func (dataDir *DataDir) Initialized() (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("%w: failed to read data directory contents at path %s: %v",
|
||||
ErrDataDirError, dataDir.path, err)
|
||||
return false, fmt.Errorf("failed to read data directory contents at path %s: %w",
|
||||
dataDir.path, err)
|
||||
}
|
||||
|
||||
return len(dataDirEntries) != 0, nil
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package controller
|
|||
import (
|
||||
"crypto/tls"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -26,6 +27,13 @@ func WithTLSConfig(tlsConfig *tls.Config) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithSSHServer(listenAddr string, signer ssh.Signer) Option {
|
||||
return func(controller *Controller) {
|
||||
controller.sshListenAddr = listenAddr
|
||||
controller.sshSigner = signer
|
||||
}
|
||||
}
|
||||
|
||||
func WithInsecureAuthDisabled() Option {
|
||||
return func(controller *Controller) {
|
||||
controller.insecureAuthDisabled = true
|
||||
|
|
|
|||
|
|
@ -0,0 +1,284 @@
|
|||
package sshserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cirruslabs/orchard/internal/controller/notifier"
|
||||
proxypkg "github.com/cirruslabs/orchard/internal/controller/proxy"
|
||||
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
|
||||
"github.com/cirruslabs/orchard/internal/proxy"
|
||||
"github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/cirruslabs/orchard/rpc"
|
||||
"github.com/google/uuid"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// "ssh -J" uses channels of type "direct-tcpip", which are documented
|
||||
// in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1].
|
||||
//
|
||||
// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
|
||||
channelTypeDirectTCPIP = "direct-tcpip"
|
||||
)
|
||||
|
||||
type SSHServer struct {
|
||||
listener net.Listener
|
||||
serverConfig *ssh.ServerConfig
|
||||
store storepkg.Store
|
||||
proxy *proxypkg.Proxy
|
||||
workerNotifier *notifier.Notifier
|
||||
logger *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func NewSSHServer(
|
||||
address string,
|
||||
signer ssh.Signer,
|
||||
store storepkg.Store,
|
||||
proxy *proxypkg.Proxy,
|
||||
workerNotifier *notifier.Notifier,
|
||||
logger *zap.SugaredLogger,
|
||||
) (*SSHServer, error) {
|
||||
server := &SSHServer{
|
||||
store: store,
|
||||
proxy: proxy,
|
||||
workerNotifier: workerNotifier,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.listener = listener
|
||||
|
||||
server.serverConfig = &ssh.ServerConfig{
|
||||
PasswordCallback: server.passwordCallback,
|
||||
}
|
||||
server.serverConfig.AddHostKey(signer)
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (server *SSHServer) Run() {
|
||||
for {
|
||||
conn, err := server.listener.Accept()
|
||||
if err != nil {
|
||||
server.logger.Warnf("failed to accept connection: %v", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
go server.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *SSHServer) Address() string {
|
||||
return strings.ReplaceAll(server.listener.Addr().String(), "[::]", "127.0.0.1")
|
||||
}
|
||||
|
||||
func (server *SSHServer) passwordCallback(connMetadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if err := server.store.View(func(txn storepkg.Transaction) error {
|
||||
// Authenticate
|
||||
server.logger.Debugf("authenticating user %q using the password authentication",
|
||||
connMetadata.User())
|
||||
|
||||
serviceAccount, err := txn.GetServiceAccount(connMetadata.User())
|
||||
if err != nil {
|
||||
if errors.Is(err, storepkg.ErrNotFound) {
|
||||
return fmt.Errorf("authentication failed, non-existent user %q",
|
||||
connMetadata.User())
|
||||
}
|
||||
|
||||
server.logger.Errorf("failed to retrieve service account %q: %v",
|
||||
connMetadata.User(), err)
|
||||
|
||||
return fmt.Errorf("authentication failed due to an internal error")
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare([]byte(serviceAccount.Token), password) != 1 {
|
||||
return fmt.Errorf("authentication failed for user %q: invalid password",
|
||||
connMetadata.User())
|
||||
}
|
||||
|
||||
// Authorize
|
||||
if !lo.Contains(serviceAccount.Roles, v1.ServiceAccountRoleComputeWrite) {
|
||||
return fmt.Errorf("authorization failed for user %q because it lacks %q role",
|
||||
connMetadata.User(), v1.ServiceAccountRoleComputeWrite)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ssh.Permissions{}, nil
|
||||
}
|
||||
|
||||
func (server *SSHServer) handleConnection(conn net.Conn) {
|
||||
sshConn, newChannelCh, requestCh, err := ssh.NewServerConn(conn, server.serverConfig)
|
||||
if err != nil {
|
||||
server.logger.Warnf("failed to instantiate the SSH server instance to handle "+
|
||||
"the incoming connection from %s: %v", conn.RemoteAddr().String(), err)
|
||||
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = sshConn.Close()
|
||||
}()
|
||||
|
||||
server.logger.Debugf("accepted SSH connection for user %q connecting from %q",
|
||||
sshConn.User(), sshConn.RemoteAddr().String())
|
||||
|
||||
connCtx, connCtxCancel := context.WithCancel(context.Background())
|
||||
defer connCtxCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case newChannel, ok := <-newChannelCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch newChannel.ChannelType() {
|
||||
case channelTypeDirectTCPIP:
|
||||
server.logger.Debugf("handling a new direct TCP/IP channel for user %q connecting from %q",
|
||||
sshConn.User(), sshConn.RemoteAddr().String())
|
||||
|
||||
go server.handleDirectTCPIP(connCtx, newChannel)
|
||||
default:
|
||||
message := fmt.Sprintf("unsupported channel type requested: %q", newChannel.ChannelType())
|
||||
|
||||
server.logger.Debugf(message)
|
||||
|
||||
if err := newChannel.Reject(ssh.UnknownChannelType, message); err != nil {
|
||||
server.logger.Warnf("failed to reject new channel of unsupported type %q: %v",
|
||||
newChannel.ChannelType(), err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
case request, ok := <-requestCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
server.logger.Debugf("refusing to service new request of type %q with payload of %d bytes",
|
||||
request.Type, len(request.Payload))
|
||||
|
||||
if err := request.Reply(false, nil); err != nil {
|
||||
server.logger.Warnf("failed to reply to a new request of type %q and payload of %d bytes: %v",
|
||||
request.Type, len(request.Payload), err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.NewChannel) {
|
||||
// Unmarshal the payload to determine to which VM the user wants to connect to
|
||||
//
|
||||
// This direct TCP/IP channel's payload is documented
|
||||
// in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1].
|
||||
//
|
||||
// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
|
||||
payload := struct {
|
||||
HostToConnect string
|
||||
PortToConnect uint32
|
||||
OriginatorIPAddress string
|
||||
OriginatorPort uint32
|
||||
}{}
|
||||
|
||||
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil {
|
||||
message := fmt.Sprintf("failed to unmarshal payload: %v", err)
|
||||
|
||||
server.logger.Warn(message)
|
||||
|
||||
if err := newChannel.Reject(ssh.ConnectionFailed, message); err != nil {
|
||||
server.logger.Warnf("failed to reject the new channel: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
server.logger.Debugf("proxying connection to %s:%d", payload.HostToConnect, payload.PortToConnect)
|
||||
|
||||
// Retrieve the VM object
|
||||
var vm *v1.VM
|
||||
var err error
|
||||
|
||||
err = server.store.View(func(txn storepkg.Transaction) error {
|
||||
vm, err = txn.GetVM(payload.HostToConnect)
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
if err := newChannel.Reject(ssh.ConnectionFailed, "failed to find VM"); err != nil {
|
||||
server.logger.Warnf("failed to reject the new channel due to non-existent VM %q: %v",
|
||||
payload.HostToConnect, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// The user wants to connect to an existing VM, request and wait
|
||||
// for a connection with the worker before accepting the channel
|
||||
session := uuid.New().String()
|
||||
boomerangConnCh, cancel := server.proxy.Request(ctx, session)
|
||||
defer cancel()
|
||||
|
||||
err = server.workerNotifier.Notify(ctx, vm.Worker, &rpc.WatchInstruction{
|
||||
Action: &rpc.WatchInstruction_PortForwardAction{
|
||||
PortForwardAction: &rpc.WatchInstruction_PortForward{
|
||||
Session: session,
|
||||
VmUid: vm.UID,
|
||||
Port: payload.PortToConnect,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
server.logger.Warnf("failed to request port-forwarding from the worker %s: %v",
|
||||
vm.Worker, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for the connection from worker and commence port forwarding
|
||||
select {
|
||||
case fromWorkerConnection := <-boomerangConnCh:
|
||||
// Now that we have the connection from worker we can accept the channel
|
||||
acceptedChannel, acceptedChannelRequests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
server.logger.Warnf("failed to accept the new channel: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Handle new requests on the accepted channel by refusing them
|
||||
go func() {
|
||||
req, ok := <-acceptedChannelRequests
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := req.Reply(false, nil); err != nil {
|
||||
server.logger.Warnf("failed to reply to the new channel request: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// Commence port forwarding
|
||||
if err := proxy.Connections(acceptedChannel, fromWorkerConnection); err != nil {
|
||||
server.logger.Warnf("failed to port-forward: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -2,11 +2,10 @@ package proxy
|
|||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Connections(left net.Conn, right net.Conn) (finalErr error) {
|
||||
func Connections(left io.ReadWriteCloser, right io.ReadWriteCloser) (finalErr error) {
|
||||
leftErrCh := make(chan error, 1)
|
||||
rightErrCh := make(chan error, 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,14 @@ package simplename
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var ErrNotASimpleName = errors.New("name contains restricted characters, please only use [A-Za-z0-9:-_.]")
|
||||
var (
|
||||
ErrNotASimpleName = errors.New("name contains restricted characters, please only use [A-Za-z0-9:-_.]")
|
||||
|
||||
ErrNotASimpleNameNext = errors.New("names with characters other than [a-z0-9-] will be deprecated in the future")
|
||||
)
|
||||
|
||||
func Validate(s string) error {
|
||||
for _, ch := range s {
|
||||
|
|
@ -29,3 +34,53 @@ func Validate(s string) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateNext(s string) error {
|
||||
// Ensure that the name is not empty
|
||||
if s == "" {
|
||||
return fmt.Errorf("name cannot be empty")
|
||||
}
|
||||
|
||||
// Ensure that the name is 63 characters or fewer
|
||||
if len(s) > 63 {
|
||||
return fmt.Errorf("names with more than 63 characters " +
|
||||
"will be depreacted in the future")
|
||||
}
|
||||
|
||||
// Ensure that the name starts and ends with an alphanumeric character
|
||||
if !isAlphanumeric(s[0]) {
|
||||
return fmt.Errorf("names not starting with an alphanumeric character " +
|
||||
"will be deprecated in the future")
|
||||
}
|
||||
|
||||
if !isAlphanumeric(s[len(s)-1]) {
|
||||
return fmt.Errorf("names not ending with an alphanumeric character " +
|
||||
"will be deprecated in the future")
|
||||
}
|
||||
|
||||
for i := range s {
|
||||
if isAlphanumeric(s[i]) {
|
||||
continue
|
||||
}
|
||||
|
||||
if s[i] == '-' {
|
||||
continue
|
||||
}
|
||||
|
||||
return ErrNotASimpleNameNext
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAlphanumeric(ch uint8) bool {
|
||||
if ch >= 'a' && ch <= 'z' {
|
||||
return true
|
||||
}
|
||||
|
||||
if ch >= '0' && ch <= '9' {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,3 +15,23 @@ func TestValidate(t *testing.T) {
|
|||
require.Error(t, simplename.Validate("vm%"), "special characters")
|
||||
require.Error(t, simplename.Validate("😐"), "non-ASCII characters")
|
||||
}
|
||||
|
||||
func TestValidateNext(t *testing.T) {
|
||||
require.NoError(t, simplename.ValidateNext("abcdefghijklmnopqrstuvwxyz-01234567890"))
|
||||
require.NoError(t, simplename.ValidateNext("vm-1"))
|
||||
require.NoError(t, simplename.ValidateNext("host-local"))
|
||||
require.NoError(t, simplename.ValidateNext("x"))
|
||||
|
||||
require.Error(t, simplename.ValidateNext("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
|
||||
"uppercase characters")
|
||||
require.Error(t, simplename.ValidateNext(".test"), "does not start with an alphanumeric character")
|
||||
require.Error(t, simplename.ValidateNext("test."), "does not end with an alphanumeric character")
|
||||
require.Error(t, simplename.ValidateNext("vm:1"), "special characters")
|
||||
require.Error(t, simplename.ValidateNext("vm_1"), "special characters")
|
||||
require.Error(t, simplename.ValidateNext("vm.1"), "special characters")
|
||||
require.Error(t, simplename.ValidateNext("vm%"), "special characters")
|
||||
require.Error(t, simplename.ValidateNext("😐"), "non-ASCII characters")
|
||||
require.Error(t, simplename.ValidateNext(""), "empty name")
|
||||
require.Error(t, simplename.ValidateNext("1234567890123456789012345678901234567890123456789012345678901234"),
|
||||
"too long")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
package devcontroller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/cirruslabs/orchard/internal/command/dev"
|
||||
"github.com/cirruslabs/orchard/internal/controller"
|
||||
"github.com/cirruslabs/orchard/internal/worker"
|
||||
"github.com/cirruslabs/orchard/pkg/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func StartIntegrationTestEnvironment(t *testing.T) (*client.Client, *controller.Controller, *worker.Worker) {
|
||||
return StartIntegrationTestEnvironmentWithAdditionalOpts(t, nil, nil)
|
||||
}
|
||||
|
||||
func StartIntegrationTestEnvironmentWithAdditionalOpts(
|
||||
t *testing.T,
|
||||
additionalControllerOpts []controller.Option,
|
||||
additionalWorkerOpts []worker.Option,
|
||||
) (*client.Client, *controller.Controller, *worker.Worker) {
|
||||
t.Setenv("ORCHARD_HOME", t.TempDir())
|
||||
|
||||
devController, devWorker, err := dev.CreateDevControllerAndWorker(t.TempDir(),
|
||||
":0", nil, additionalControllerOpts, additionalWorkerOpts)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = devWorker.Close()
|
||||
})
|
||||
|
||||
devContext, cancelDevFunc := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancelDevFunc)
|
||||
|
||||
go func() {
|
||||
err := devController.Run(devContext)
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Errorf("dev controller failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := devWorker.Run(devContext)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("dev worker failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
devClient, err := client.New(client.WithAddress(devController.Address()))
|
||||
require.NoError(t, err)
|
||||
|
||||
return devClient, devController, devWorker
|
||||
}
|
||||
|
|
@ -2,14 +2,12 @@ package tests_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cirruslabs/orchard/internal/command/dev"
|
||||
"github.com/cirruslabs/orchard/internal/controller"
|
||||
"github.com/cirruslabs/orchard/internal/worker"
|
||||
"github.com/cirruslabs/orchard/internal/tests/devcontroller"
|
||||
"github.com/cirruslabs/orchard/internal/tests/wait"
|
||||
"github.com/cirruslabs/orchard/internal/worker/ondiskname"
|
||||
"github.com/cirruslabs/orchard/internal/worker/tart"
|
||||
"github.com/cirruslabs/orchard/pkg/client"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -18,7 +16,6 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/slices"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
|
@ -28,7 +25,7 @@ import (
|
|||
)
|
||||
|
||||
func TestSingleVM(t *testing.T) {
|
||||
devClient := StartIntegrationTestEnvironment(t)
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
workers, err := devClient.Workers().List(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -52,7 +49,7 @@ func TestSingleVM(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, Wait(2*time.Minute, func() bool {
|
||||
assert.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err := devClient.VMs().Get(context.Background(), "test-vm")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -66,7 +63,7 @@ func TestSingleVM(t *testing.T) {
|
|||
}
|
||||
assert.Empty(t, runningVM.StatusMessage)
|
||||
assert.Equal(t, v1.VMStatusRunning, runningVM.Status)
|
||||
assert.True(t, Wait(2*time.Minute, func() bool {
|
||||
assert.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
logLines, err := devClient.VMs().Logs(context.Background(), "test-vm")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -92,7 +89,7 @@ func TestSingleVM(t *testing.T) {
|
|||
require.NoError(t, devClient.VMs().Delete(context.Background(), "test-vm"))
|
||||
|
||||
// Ensure that the worker has deleted this VM from disk
|
||||
assert.True(t, Wait(2*time.Minute, func() bool {
|
||||
assert.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
t.Logf("Waiting for the VM to be garbage collected...")
|
||||
|
||||
return !hasVMByPredicate(t, func(info tart.VMInfo) bool {
|
||||
|
|
@ -102,7 +99,7 @@ func TestSingleVM(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFailedStartupScript(t *testing.T) {
|
||||
devClient := StartIntegrationTestEnvironment(t)
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
workers, err := devClient.Workers().List(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -125,7 +122,7 @@ func TestFailedStartupScript(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, Wait(2*time.Minute, func() bool {
|
||||
assert.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err := devClient.VMs().Get(context.Background(), "test-vm")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -141,72 +138,10 @@ func TestFailedStartupScript(t *testing.T) {
|
|||
"failed to run startup script: Process exited with status 123")
|
||||
}
|
||||
|
||||
func Wait(duration time.Duration, condition func() bool) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration)
|
||||
defer cancel()
|
||||
for {
|
||||
if condition() {
|
||||
// all good
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(5 * time.Second):
|
||||
// try again
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func StartIntegrationTestEnvironment(
|
||||
t *testing.T,
|
||||
) *client.Client {
|
||||
return StartIntegrationTestEnvironmentWithAdditionalOpts(t, nil, nil)
|
||||
}
|
||||
|
||||
func StartIntegrationTestEnvironmentWithAdditionalOpts(
|
||||
t *testing.T,
|
||||
additionalControllerOpts []controller.Option,
|
||||
additionalWorkerOpts []worker.Option,
|
||||
) *client.Client {
|
||||
t.Setenv("ORCHARD_HOME", t.TempDir())
|
||||
devController, devWorker, err := dev.CreateDevControllerAndWorker(t.TempDir(),
|
||||
":0", nil, additionalControllerOpts, additionalWorkerOpts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = devWorker.Close()
|
||||
})
|
||||
devContext, cancelDevFunc := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancelDevFunc)
|
||||
go func() {
|
||||
err := devController.Run(devContext)
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Errorf("dev controller failed: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
err := devWorker.Run(devContext)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("dev worker failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
devClient, err := client.New(client.WithAddress(devController.Address()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return devClient
|
||||
}
|
||||
|
||||
func TestPortForwarding(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
devClient := StartIntegrationTestEnvironment(t)
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
// Create a generic macOS VM
|
||||
err := devClient.VMs().Create(ctx, &v1.VM{
|
||||
|
|
@ -261,7 +196,7 @@ func TestPortForwarding(t *testing.T) {
|
|||
func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
devClient := StartIntegrationTestEnvironment(t)
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
const (
|
||||
dummyWorkerName = "dummy-worker"
|
||||
|
|
@ -304,7 +239,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Wait for the dummy VM to get scheduled to a dummy worker
|
||||
require.True(t, Wait(2*time.Minute, func() bool {
|
||||
require.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err := devClient.VMs().Get(context.Background(), dummyVMName)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -318,7 +253,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Wait for the scheduler to change the dummy VM's status to "failed"
|
||||
require.True(t, Wait(2*time.Minute, func() bool {
|
||||
require.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err := devClient.VMs().Get(context.Background(), dummyVMName)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -340,7 +275,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
|
|||
func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
devClient := StartIntegrationTestEnvironmentWithAdditionalOpts(t,
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironmentWithAdditionalOpts(t,
|
||||
[]controller.Option{controller.WithWorkerOfflineTimeout(1 * time.Minute)}, nil)
|
||||
|
||||
const (
|
||||
|
|
@ -379,7 +314,7 @@ func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Wait for the VM to be marked as failed
|
||||
assert.True(t, Wait(2*time.Minute, func() bool {
|
||||
assert.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err := devClient.VMs().Get(context.Background(), dummyVMName)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -415,10 +350,10 @@ func TestVMGarbageCollection(t *testing.T) {
|
|||
require.True(t, hasVM(t, vmName, logger))
|
||||
|
||||
// Start the Orchard Worker
|
||||
_ = StartIntegrationTestEnvironment(t)
|
||||
_, _, _ = devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
// Wait for the Orchard Worker to garbage-collect this VM
|
||||
require.True(t, Wait(2*time.Minute, func() bool {
|
||||
require.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
t.Logf("Waiting for the on-disk VM to be cleaned up by the worker")
|
||||
|
||||
return !hasVM(t, vmName, logger)
|
||||
|
|
@ -426,7 +361,7 @@ func TestVMGarbageCollection(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHostDirs(t *testing.T) {
|
||||
devClient := StartIntegrationTestEnvironment(t)
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
dirToMount := t.TempDir()
|
||||
|
||||
|
|
@ -461,7 +396,7 @@ func TestHostDirs(t *testing.T) {
|
|||
|
||||
var vm *v1.VM
|
||||
|
||||
require.True(t, Wait(2*time.Minute, func() bool {
|
||||
require.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err = devClient.VMs().Get(context.Background(), vmName)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -475,7 +410,7 @@ func TestHostDirs(t *testing.T) {
|
|||
|
||||
var logLines []string
|
||||
|
||||
require.True(t, Wait(2*time.Minute, func() bool {
|
||||
require.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
logLines, err = devClient.VMs().Logs(context.Background(), vmName)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -495,7 +430,7 @@ func TestHostDirs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHostDirsInvalidPolicy(t *testing.T) {
|
||||
devClient := StartIntegrationTestEnvironment(t)
|
||||
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
|
||||
|
||||
dirToMount := t.TempDir()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"github.com/cirruslabs/orchard/internal/controller"
|
||||
"github.com/cirruslabs/orchard/internal/tests/devcontroller"
|
||||
"github.com/cirruslabs/orchard/internal/tests/wait"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSSHServer(t *testing.T) {
|
||||
// Generate SSH host key for the Controller
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
sshPublicKey, err := ssh.NewPublicKey(publicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
signer, err := ssh.NewSignerFromKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run the Controller
|
||||
devClient, devController, _ := devcontroller.StartIntegrationTestEnvironmentWithAdditionalOpts(t, []controller.Option{
|
||||
controller.WithSSHServer(":0", signer),
|
||||
}, nil)
|
||||
|
||||
// Create a VM to which we'll connect via Controller's SSH server
|
||||
err = devClient.VMs().Create(context.Background(), &v1.VM{
|
||||
Meta: v1.Meta{
|
||||
Name: "test-vm",
|
||||
},
|
||||
Image: "ghcr.io/cirruslabs/macos-sonoma-base:latest",
|
||||
CPU: 4,
|
||||
Memory: 8 * 1024,
|
||||
Headless: true,
|
||||
Status: v1.VMStatusPending,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the VM to start
|
||||
assert.True(t, wait.Wait(2*time.Minute, func() bool {
|
||||
vm, err := devClient.VMs().Get(context.Background(), "test-vm")
|
||||
require.NoError(t, err)
|
||||
|
||||
return vm.Status == v1.VMStatusRunning
|
||||
}), "failed to wait for the VM to start")
|
||||
|
||||
// Create a service account whose credentials we'll use to connect to the Controller's SSH server
|
||||
require.NoError(t, devClient.ServiceAccounts().Create(context.Background(), &v1.ServiceAccount{
|
||||
Meta: v1.Meta{
|
||||
Name: "ssh-user",
|
||||
},
|
||||
Token: "ssh-password",
|
||||
Roles: []v1.ServiceAccountRole{
|
||||
v1.ServiceAccountRoleComputeWrite,
|
||||
},
|
||||
}))
|
||||
|
||||
// Connect to the VM over Orchard Controller's SSH server
|
||||
sshAddress, ok := devController.SSHAddress()
|
||||
require.True(t, ok)
|
||||
|
||||
sshClientController, err := ssh.Dial("tcp", sshAddress, &ssh.ClientConfig{
|
||||
User: "ssh-user",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password("ssh-password"),
|
||||
},
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
if subtle.ConstantTimeCompare(sshPublicKey.Marshal(), key.Marshal()) != 1 {
|
||||
return fmt.Errorf("untrustred public key was presented by the Controller")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
netConnVM, err := sshClientController.Dial("tcp", "test-vm:22")
|
||||
require.NoError(t, err)
|
||||
|
||||
sshConnVM, sshChansVM, sshReqsVM, err := ssh.NewClientConn(netConnVM, "test-vm:22", &ssh.ClientConfig{
|
||||
User: "admin",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password("admin"),
|
||||
},
|
||||
HostKeyCallback: func(_ string, _ net.Addr, _ ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sshClientVM := ssh.NewClient(sshConnVM, sshChansVM, sshReqsVM)
|
||||
|
||||
sshSessVM, err := sshClientVM.NewSession()
|
||||
require.NoError(t, err)
|
||||
|
||||
unameBytes, err := sshSessVM.Output("uname -a")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(unameBytes), "Darwin")
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
package wait
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Wait(duration time.Duration, predicate func() bool) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
if predicate() {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(1 * time.Second):
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue